Skip to content

Commit 210d8ae

Browse files
authored
Fix parfor tests (#180)
1 parent 58a125c commit 210d8ae

File tree

5 files changed

+128
-40
lines changed

5 files changed

+128
-40
lines changed

dpcomp_runtime/lib/tbb_parallel.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,20 +126,23 @@ static void run_parallel_for(const InputRange *input_ranges, size_t depth,
126126
auto runFunc = [&](Dim *current) {
127127
auto thread_index =
128128
static_cast<size_t>(tbb::this_task_arena::current_thread_index());
129+
assert(thread_index >= 0);
129130
std::array<Range, 8> static_ranges;
130131
std::unique_ptr<Range[]> dyn_ranges;
131132
auto *range_ptr = [&]() -> Range * {
132-
if (num_loops <= static_ranges.size()) {
133+
if (num_loops <= static_ranges.size())
133134
return static_ranges.data();
134-
}
135+
135136
dyn_ranges.reset(new Range[num_loops]);
136137
return dyn_ranges.get();
137138
}();
138139

139140
for (size_t i = 0; i < num_loops; ++i) {
141+
assert(current);
140142
range_ptr[num_loops - i - 1] = current->val;
141143
current = current->prev;
142144
}
145+
143146
if (DEBUG) {
144147
std::lock_guard<std::mutex> lock(getDebugMutext());
145148
fprintf(stderr, "parallel_for func: thread_index=%d",
@@ -174,7 +177,7 @@ static void run_parallel_for(const InputRange *input_ranges, size_t depth,
174177
runFunc(prev);
175178
} else {
176179
auto next = depth + N;
177-
parallel_for_nested(input_ranges, next, num_threads, num_loops, prev_dim,
180+
parallel_for_nested(input_ranges, next, num_threads, num_loops, prev,
178181
func, ctx);
179182
}
180183
};

mlir/lib/dialect/plier_util/dialect.cpp

Lines changed: 96 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <mlir/IR/BuiltinTypes.h>
2121
#include <mlir/IR/DialectImplementation.h>
2222
#include <mlir/IR/Dominance.h>
23+
#include <mlir/IR/Matchers.h>
2324
#include <mlir/IR/PatternMatch.h>
2425
#include <mlir/Transforms/InliningUtils.h>
2526

@@ -269,14 +270,91 @@ struct GenGlobalId : public mlir::OpRewritePattern<mlir::arith::AddIOp> {
269270
return mlir::failure();
270271
}
271272
};
273+
274+
/// Canonicalizes `arith.cmpi` so that a constant operand ends up on the
/// right-hand side.
///
/// The pattern fires only when the left operand matches `m_Constant()` and the
/// right operand does not. It then swaps the two operands and mirrors the
/// predicate so the comparison keeps its meaning (`slt` <-> `sgt`,
/// `sle` <-> `sge`, `ult` <-> `ugt`, `ule` <-> `uge`; `eq` and `ne` map to
/// themselves).
struct InvertCmpi : public mlir::OpRewritePattern<mlir::arith::CmpIOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::arith::CmpIOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto lhs = op.getLhs();
    auto rhs = op.getRhs();

    // Nothing to do unless the constant sits on the left only.
    if (!mlir::matchPattern(lhs, mlir::m_Constant()) ||
        mlir::matchPattern(rhs, mlir::m_Constant()))
      return mlir::failure();

    auto swappedPred = mirrorPredicate(op.getPredicate());
    rewriter.replaceOpWithNewOp<mlir::arith::CmpIOp>(op, swappedPred, rhs, lhs);
    return mlir::success();
  }

private:
  using Pred = mlir::arith::CmpIPredicate;

  /// Returns the predicate that yields the same result once the two compared
  /// operands have traded places.
  static Pred mirrorPredicate(Pred pred) {
    switch (pred) {
    case Pred::slt:
      return Pred::sgt;
    case Pred::sgt:
      return Pred::slt;
    case Pred::sle:
      return Pred::sge;
    case Pred::sge:
      return Pred::sle;
    case Pred::ult:
      return Pred::ugt;
    case Pred::ugt:
      return Pred::ult;
    case Pred::ule:
      return Pred::uge;
    case Pred::uge:
      return Pred::ule;
    case Pred::eq:
      return Pred::eq;
    case Pred::ne:
      return Pred::ne;
    }

    llvm_unreachable("Unknown predicate");
  }
};
315+
316+
/// Turns the `memref.alloc` that produces the shape operand of a
/// `memref.reshape` into a `memref.alloca`.
///
/// The rewrite is applied only when:
///  * the shape operand is defined by a `memref.alloc`;
///  * every user of that alloc is a `memref.store` or a `memref.reshape`;
///  * the alloc has no dynamic sizes and no symbol operands;
///  * the reshape is nested inside a `FuncOp`.
///
/// If the alloc is not an immediate child of the function, the new alloca is
/// created at the start of the function's first block; otherwise it is created
/// at the position of the alloc.
/// NOTE(review): presumably the hoisting keeps the alloca from being
/// re-executed inside nested regions (e.g. loops) - confirm.
struct ReshapeAlloca : public mlir::OpRewritePattern<mlir::memref::ReshapeOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::memref::ReshapeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto shapeOp = op.shape().getDefiningOp<mlir::memref::AllocOp>();
    if (!shapeOp)
      return mlir::failure();

    // Any other kind of user could observe the buffer in a way this pattern
    // does not reason about, so bail out.
    for (auto user : shapeOp->getUsers())
      if (!mlir::isa<mlir::memref::StoreOp, mlir::memref::ReshapeOp>(user))
        return mlir::failure();

    if (!shapeOp.dynamicSizes().empty() || !shapeOp.symbolOperands().empty())
      return mlir::failure();

    auto func = op->getParentOfType<mlir::FuncOp>();
    if (!func)
      return mlir::failure();

    if (shapeOp->getParentOp() != func) {
      rewriter.setInsertionPointToStart(&func.getBody().front());
    } else {
      rewriter.setInsertionPoint(shapeOp);
    }

    auto type = shapeOp.getType().cast<mlir::MemRefType>();
    // The alignment attribute may be absent on the alloc. `cast<>` asserts on
    // a null attribute, so forward a possibly-null attribute instead; a null
    // value makes the new alloca carry no alignment either.
    auto alignment =
        shapeOp.alignmentAttr().dyn_cast_or_null<mlir::IntegerAttr>();
    rewriter.replaceOpWithNewOp<mlir::memref::AllocaOp>(shapeOp, type,
                                                        alignment);
    return mlir::success();
  }
};
272350
} // namespace
273351

274352
void PlierUtilDialect::getCanonicalizationPatterns(
275353
mlir::RewritePatternSet &results) const {
276354
results.add<DimExpandShape<mlir::tensor::DimOp, mlir::tensor::ExpandShapeOp>,
277355
DimExpandShape<mlir::memref::DimOp, mlir::memref::ExpandShapeOp>,
278-
DimInsertSlice, FillExtractSlice, SpirvInputCSE, GenGlobalId>(
279-
getContext());
356+
DimInsertSlice, FillExtractSlice, SpirvInputCSE, GenGlobalId,
357+
InvertCmpi, ReshapeAlloca>(getContext());
280358
}
281359

282360
OpaqueType OpaqueType::get(mlir::MLIRContext *context) {
@@ -330,35 +408,35 @@ EnforceShapeOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
330408
}
331409

332410
namespace {
333-
struct EnforceShapeDim : public mlir::OpRewritePattern<mlir::memref::DimOp> {
334-
using mlir::OpRewritePattern<mlir::memref::DimOp>::OpRewritePattern;
411+
template <typename DimOp>
412+
struct EnforceShapeDim : public mlir::OpRewritePattern<DimOp> {
413+
using mlir::OpRewritePattern<DimOp>::OpRewritePattern;
335414

336415
mlir::LogicalResult
337-
matchAndRewrite(mlir::memref::DimOp op,
338-
mlir::PatternRewriter &rewriter) const override {
339-
auto enforce_op = mlir::dyn_cast_or_null<plier::EnforceShapeOp>(
340-
op.source().getDefiningOp());
341-
if (!enforce_op) {
416+
matchAndRewrite(DimOp op, mlir::PatternRewriter &rewriter) const override {
417+
auto enforceOp =
418+
op.source().template getDefiningOp<plier::EnforceShapeOp>();
419+
if (!enforceOp)
342420
return mlir::failure();
343-
}
344-
auto const_ind = plier::getConstVal<mlir::IntegerAttr>(op.index());
345-
if (!const_ind) {
421+
422+
auto constInd = mlir::getConstantIntValue(op.index());
423+
if (!constInd)
346424
return mlir::failure();
347-
}
348-
auto index = const_ind.getInt();
349-
if (index < 0 || index >= static_cast<int64_t>(enforce_op.sizes().size())) {
425+
426+
auto index = *constInd;
427+
if (index < 0 || index >= static_cast<int64_t>(enforceOp.sizes().size()))
350428
return mlir::failure();
351-
}
352429

353-
rewriter.replaceOp(op, enforce_op.sizes()[static_cast<unsigned>(index)]);
430+
rewriter.replaceOp(op, enforceOp.sizes()[static_cast<unsigned>(index)]);
354431
return mlir::success();
355432
}
356433
};
357434
} // namespace
358435

359436
void EnforceShapeOp::getCanonicalizationPatterns(
360437
::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context) {
361-
results.insert<EnforceShapeDim>(context);
438+
results.insert<EnforceShapeDim<mlir::tensor::DimOp>,
439+
EnforceShapeDim<mlir::memref::DimOp>>(context);
362440
}
363441

364442
mlir::LogicalResult

numba_dpcomp/numba_dpcomp/mlir/tests/test_numba_parfor.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def _gen_tests():
5252
]
5353

5454
xfail_tests = {
55+
"test_prange26",
5556
"test_prange03mul",
5657
"test_prange09",
5758
"test_prange03sub",
@@ -182,12 +183,7 @@ def _gen_tests():
182183
"test_one_d_array_reduction",
183184
}
184185

185-
skip_tests = {
186-
"test_copy_global_for_parfor",
187-
"test_high_dimension1",
188-
"test_three_d_array_reduction",
189-
"test_prange26",
190-
}
186+
skip_tests = {}
191187

192188
def countParfors(test_func, args, **kws):
193189
pytest.xfail()

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/plier_to_linalg.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,8 +2851,6 @@ static void populatePlierToLinalgOptPipeline(mlir::OpPassManager &pm) {
28512851
mlir::bufferization::createBufferHoistingPass());
28522852
pm.addNestedPass<mlir::FuncOp>(
28532853
mlir::bufferization::createBufferLoopHoistingPass());
2854-
pm.addNestedPass<mlir::FuncOp>(
2855-
mlir::bufferization::createPromoteBuffersToStackPass());
28562854

28572855
pm.addNestedPass<mlir::FuncOp>(std::make_unique<CloneArgsPass>());
28582856
pm.addPass(std::make_unique<MakeStridedLayoutPass>());
@@ -2863,6 +2861,8 @@ static void populatePlierToLinalgOptPipeline(mlir::OpPassManager &pm) {
28632861
pm.addPass(mlir::createCanonicalizerPass());
28642862

28652863
pm.addNestedPass<mlir::FuncOp>(std::make_unique<LowerCloneOpsPass>());
2864+
pm.addNestedPass<mlir::FuncOp>(
2865+
mlir::bufferization::createPromoteBuffersToStackPass());
28662866

28672867
pm.addPass(std::make_unique<LowerLinalgPass>());
28682868
pm.addPass(plier::createForceInlinePass());

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/plier_to_scf.cpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -432,17 +432,17 @@ struct ScfWhileRewrite : public mlir::OpRewritePattern<mlir::cf::BranchOp> {
432432
}
433433
}
434434

435-
llvm::transform(
436-
beforeBlock->getArguments(), yieldVars.begin(),
437-
[&](mlir::Value val) { return mapper.lookupOrDefault(val); });
435+
llvm::transform(beforeBlockArgs, yieldVars.begin(), [&](mlir::Value val) {
436+
return mapper.lookupOrDefault(val);
437+
});
438438

439-
auto term =
440-
mlir::cast<mlir::cf::CondBranchOp>(beforeBlock->getTerminator());
441-
for (auto arg : term.getFalseDestOperands()) {
439+
auto falseArgs = beforeTerm.getFalseDestOperands();
440+
for (auto arg : falseArgs) {
442441
origVars.emplace_back(arg);
443442
yieldVars.emplace_back(mapper.lookupOrDefault(arg));
444443
}
445-
auto cond = mapper.lookupOrDefault(term.getCondition());
444+
445+
auto cond = mapper.lookupOrDefault(beforeTerm.getCondition());
446446
builder.create<mlir::scf::ConditionOp>(loc, cond, yieldVars);
447447
};
448448
auto afterBody = [&](mlir::OpBuilder &builder, mlir::Location loc,
@@ -465,11 +465,22 @@ struct ScfWhileRewrite : public mlir::OpRewritePattern<mlir::cf::BranchOp> {
465465
beforeBody, afterBody);
466466

467467
assert(origVars.size() == whileOp.getNumResults());
468-
for (auto arg : llvm::zip(origVars, whileOp.getResults()))
469-
std::get<0>(arg).replaceAllUsesWith(std::get<1>(arg));
468+
for (auto arg : llvm::zip(origVars, whileOp.getResults())) {
469+
auto origVal = std::get<0>(arg);
470+
for (auto &use : llvm::make_early_inc_range(origVal.getUses())) {
471+
auto *owner = use.getOwner();
472+
auto *block = owner->getBlock();
473+
if (block != &whileOp.getBefore().front() &&
474+
block != &whileOp.getAfter().front()) {
475+
auto newVal = std::get<1>(arg);
476+
rewriter.updateRootInPlace(owner, [&]() { use.set(newVal); });
477+
}
478+
}
479+
}
470480

471-
rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(
472-
op, postBlock, beforeTerm.getFalseDestOperands());
481+
auto results =
482+
whileOp.getResults().take_back(beforeTerm.getNumFalseOperands());
483+
rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(op, postBlock, results);
473484
return mlir::success();
474485
}
475486
};

0 commit comments

Comments
 (0)