Skip to content

Commit f348c1e

Browse files
authored
Fix some parfor tests (#181)
1 parent 210d8ae commit f348c1e

File tree

3 files changed

+112
-26
lines changed

3 files changed

+112
-26
lines changed

azure-pipelines.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
- script: |
4242
call "C:\Miniconda\Scripts\activate"
4343
call cd numba_dpcomp
44-
pytest -n1 -vv --capture=tee-sys
44+
pytest -n1 -vv --capture=tee-sys -rX
4545
displayName: 'Tests'
4646
4747
- script: |
@@ -127,7 +127,7 @@ jobs:
127127
export SYCL_DEVICE_FILTER=opencl:cpu
128128
export NUMBA_DISABLE_PERFORMANCE_WARNINGS=1
129129
export DPCOMP_ENABLE_PARFOR_TESTS=1
130-
pytest -n1 -vv --capture=tee-sys
130+
pytest -n1 -vv --capture=tee-sys -rX
131131
displayName: 'Tests'
132132
133133
- script: |
@@ -197,7 +197,7 @@ jobs:
197197
source /usr/local/miniconda/bin/activate
198198
cd numba_dpcomp
199199
conda activate test_env
200-
pytest -n1 -vv --capture=tee-sys
200+
pytest -n1 -vv --capture=tee-sys -rX
201201
displayName: 'Tests'
202202
203203
- script: |

mlir/lib/dialect/plier_util/dialect.cpp

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1416,6 +1416,95 @@ struct SignCastMemrefSubviewPropagate
14161416
}
14171417
};
14181418

1419+
/// Canonicalization pattern that moves `plier::SignCastOp`s off the
/// loop-carried values of an `scf.for`.
///
/// An iteration argument is rewritten when both its init value and the value
/// yielded for it are produced by a `plier::SignCastOp`, and the two casts
/// have the same source type. For those arguments the new loop carries the
/// cast *source* value; casts back to the original type are inserted at the
/// top of the new body and after the new loop so that all existing users keep
/// seeing the original types. The yielded value is cast to the carried type
/// at the end of the body.
///
/// The pattern fails to match (returns `failure`) when no iteration argument
/// qualifies.
struct SignCastForPropagate : public mlir::OpRewritePattern<mlir::scf::ForOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::scf::ForOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto &oldBody = op.getLoopBody().front();
    auto oldYield = mlir::cast<mlir::scf::YieldOp>(oldBody.getTerminator());
    auto yielded = oldYield.getResults();
    auto oldInits = op.getInitArgs();
    const auto numIterArgs = static_cast<unsigned>(oldInits.size());

    assert(yielded.size() == numIterArgs);

    // Start from the original init values and swap in the cast source for
    // every iteration argument that qualifies.
    llvm::SmallVector<mlir::Value> strippedInits(oldInits.begin(),
                                                 oldInits.end());
    bool changed = false;
    for (unsigned idx = 0; idx < numIterArgs; ++idx) {
      mlir::Value init = oldInits[idx];
      mlir::Value yieldVal = yielded[idx];
      assert(init.getType() == yieldVal.getType());

      auto initCast = init.getDefiningOp<plier::SignCastOp>();
      auto yieldCast = yieldVal.getDefiningOp<plier::SignCastOp>();
      if (!initCast || !yieldCast)
        continue;

      // Both casts must start from the same type, otherwise the loop-carried
      // type would not be consistent between entry and back-edge.
      if (initCast.value().getType() != yieldCast.value().getType())
        continue;

      strippedInits[idx] = initCast.value();
      changed = true;
    }

    if (!changed)
      return mlir::failure();

    // Populates the body of the replacement loop by cloning the old body with
    // the block arguments remapped.
    auto buildBody = [&](mlir::OpBuilder &builder, mlir::Location loc,
                         mlir::Value iv, mlir::ValueRange newIterArgs) {
      assert(newIterArgs.size() == numIterArgs);
      mlir::BlockAndValueMapping mapper;
      mapper.map(oldBody.getArguments()[0], iv);

      auto oldIterArgs = oldBody.getArguments().drop_front(1);
      for (unsigned idx = 0; idx < numIterArgs; ++idx) {
        mlir::Value replacement = newIterArgs[idx];
        mlir::Value oldArg = oldIterArgs[idx];
        auto expectedType = oldArg.getType();
        // Cloned ops still expect the old type: cast the new argument back.
        if (replacement.getType() != expectedType)
          replacement =
              builder.create<plier::SignCastOp>(loc, expectedType, replacement)
                  .getResult();

        mapper.map(oldArg, replacement);
      }

      for (auto &oldOp : oldBody.without_terminator())
        builder.clone(oldOp, mapper);

      llvm::SmallVector<mlir::Value> yieldOperands;
      yieldOperands.reserve(numIterArgs);
      for (unsigned idx = 0; idx < numIterArgs; ++idx) {
        mlir::Value mapped = mapper.lookupOrDefault(yielded[idx]);
        auto carriedType = strippedInits[idx].getType();
        // The yielded value must match the (possibly changed) carried type.
        if (mapped.getType() != carriedType)
          mapped = builder.create<plier::SignCastOp>(loc, carriedType, mapped);

        yieldOperands.push_back(mapped);
      }
      builder.create<mlir::scf::YieldOp>(loc, yieldOperands);
    };

    auto loc = op->getLoc();
    auto newLoop = rewriter.create<mlir::scf::ForOp>(
        loc, op.getLowerBound(), op.getUpperBound(), op.getStep(),
        strippedInits, buildBody);
    auto newResults = newLoop->getResults();

    // Cast results whose type changed back to what users of the old loop
    // expect.
    llvm::SmallVector<mlir::Value> replacements;
    replacements.reserve(numIterArgs);
    for (unsigned idx = 0; idx < numIterArgs; ++idx) {
      auto originalType = oldInits[idx].getType();
      mlir::Value result = newResults[idx];
      if (result.getType() != originalType)
        result = rewriter.create<plier::SignCastOp>(loc, originalType, result);

      replacements.push_back(result);
    }

    rewriter.replaceOp(op, replacements);
    return mlir::success();
  }
};
1507+
14191508
} // namespace
14201509

14211510
void SignCastOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
@@ -1428,7 +1517,7 @@ void SignCastOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
14281517
SignCastAllocPropagate<mlir::memref::AllocaOp>,
14291518
SignCastTensorFromElementsPropagate, SignCastTensorCollapseShapePropagate,
14301519
SignCastTensorToMemrefPropagate, SignCastMemrefToTensorPropagate,
1431-
SignCastMemrefSubviewPropagate>(context);
1520+
SignCastMemrefSubviewPropagate, SignCastForPropagate>(context);
14321521
}
14331522

14341523
void ExtractMemrefMetadataOp::build(::mlir::OpBuilder &odsBuilder,

numba_dpcomp/numba_dpcomp/mlir/tests/test_numba_parfor.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,33 +52,20 @@ def _gen_tests():
5252
]
5353

5454
xfail_tests = {
55-
"test_prange26",
56-
"test_prange03mul",
5755
"test_prange09",
5856
"test_prange03sub",
59-
"test_prange10",
60-
"test_prange03",
6157
"test_prange03div",
6258
"test_prange07",
6359
"test_prange06",
64-
"test_prange16",
6560
"test_prange12",
66-
"test_prange04",
67-
"test_prange13",
6861
"test_prange25",
69-
"test_prange21",
70-
"test_prange14",
7162
"test_prange18",
7263
"test_prange_nested_reduction1",
7364
"test_list_setitem_hoisting",
74-
"test_prange23",
75-
"test_prange24",
7665
"test_list_comprehension_prange",
77-
"test_prange22",
7866
"test_prange_raises_invalid_step_size",
7967
"test_issue7501",
8068
"test_parfor_race_1",
81-
"test_check_alias_analysis",
8269
"test_nested_parfor_push_call_vars",
8370
"test_record_array_setitem_yield_array",
8471
"test_record_array_setitem",
@@ -127,7 +114,6 @@ def _gen_tests():
127114
"test_namedtuple2",
128115
"test_simple19",
129116
"test_no_hoisting_with_member_function_call",
130-
"test_reduction_var_reuse",
131117
"test_parfor_dtype_type",
132118
"test_tuple3",
133119
"test_parfor_array_access3",
@@ -179,7 +165,6 @@ def _gen_tests():
179165
"test_parfor_generate_fuse",
180166
"test_parfor_slice7",
181167
"test_parfor_bitmask2",
182-
"test_parfor_alias2",
183168
"test_one_d_array_reduction",
184169
}
185170

@@ -217,19 +202,31 @@ def wrapper(*args, **kwargs):
217202
return wrapper
218203

219204
def _gen_parallel_fastmath(self, func):
    """Wrap ``func`` so it is compiled with ``parallel=True, fastmath=True``
    and the generated LLVM-dialect IR is checked for fastmath flags.

    The returned wrapper compiles and runs ``func`` with the given arguments
    while the IR printed after the ``PostLLVMLowering`` pass is captured.
    If the IR contains any floating-point op from ``ops``, at least one such
    line must also carry ``llvm.fastmath<fast>``; otherwise an
    ``AssertionError`` with the captured IR as message is raised. IR with no
    floating-point ops at all passes the check.

    Returns the wrapper; the wrapper returns the result of the compiled call.
    """
    ops = (
        "fadd",
        "fsub",
        "fmul",
        "fdiv",
        "frem",
        "fcmp",
    )

    def wrapper(*args, **kwargs):
        with print_pass_ir([], ["PostLLVMLowering"]):
            res = njit(parallel=True, fastmath=True)(func)(*args, **kwargs)
            ir = get_print_buffer()
            # Check some fastmath llvm flags were generated
            op_count = 0
            fast_count = 0
            for line in ir.splitlines():
                # Each line is counted at most once, whichever op it matches.
                if any("llvm." + op in line for op in ops):
                    op_count += 1
                    if "llvm.fastmath<fast>" in line:
                        fast_count += 1
            if op_count > 0:
                # Report the captured IR on failure (the message previously
                # referenced an undefined name, raising NameError instead).
                assert fast_count > 0, ir
            return res

    return wrapper

0 commit comments

Comments
 (0)