Skip to content

Commit af27639

Browse files
authored
Affine if simplification (#258)
* Affine if simplification * fix and add test
1 parent f73f194 commit af27639

File tree

2 files changed

+160
-12
lines changed

2 files changed

+160
-12
lines changed

lib/polygeist/Ops.cpp

Lines changed: 99 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2665,38 +2665,125 @@ struct AffineIfSinking : public OpRewritePattern<AffineIfOp> {
26652665
auto par = dyn_cast<AffineParallelOp>(op->getParentOp());
26662666
if (!par)
26672667
return failure();
2668+
if (!isa<AffineYieldOp>(op->getNextNode()))
2669+
return failure();
2670+
2671+
bool failed = false;
2672+
op->walk([&](Operation *sub) {
2673+
if (sub != op) {
2674+
for (auto oper : sub->getOperands()) {
2675+
if (par.getRegion().isAncestor(((Value)oper).getParentRegion()) &&
2676+
!op.thenRegion().isAncestor(((Value)oper).getParentRegion())) {
2677+
failed = true;
2678+
return;
2679+
}
2680+
}
2681+
}
2682+
});
2683+
if (failed)
2684+
return failure();
2685+
26682686
if (par.getSteps().size() != op.getIntegerSet().getConstraints().size())
26692687
return failure();
2688+
26702689
for (auto cst : llvm::enumerate(op.getIntegerSet().getConstraints())) {
26712690
auto opd = cst.value().dyn_cast<AffineDimExpr>();
2672-
if (!opd)
2691+
if (!opd) {
2692+
opd = (-cst.value()).dyn_cast<AffineDimExpr>();
2693+
}
2694+
if (!opd) {
26732695
return failure();
2674-
if (op.getOperands()[opd.getPosition()] != par.getIVs()[cst.index()])
2696+
}
2697+
if (op.getOperands()[opd.getPosition()] != par.getIVs()[cst.index()]) {
26752698
return failure();
2676-
if (!op.getIntegerSet().isEq(cst.index()))
2699+
}
2700+
if (!op.getIntegerSet().isEq(cst.index())) {
26772701
return failure();
2702+
}
26782703

26792704
for (auto lb : par.getLowerBoundMap(cst.index()).getResults()) {
26802705
auto opd = lb.dyn_cast<AffineConstantExpr>();
2681-
if (!opd)
2706+
if (!opd) {
26822707
return failure();
2683-
if (opd.getValue() > 0)
2708+
}
2709+
if (opd.getValue() > 0) {
26842710
return failure();
2711+
}
26852712
}
26862713
for (auto ub : par.getUpperBoundMap(cst.index()).getResults()) {
26872714
auto opd = ub.dyn_cast<AffineConstantExpr>();
2688-
if (!opd)
2715+
if (!opd) {
26892716
return failure();
2690-
if (opd.getValue() <= 0)
2717+
}
2718+
if (opd.getValue() <= 0) {
26912719
return failure();
2720+
}
2721+
}
2722+
}
2723+
2724+
rewriter.eraseOp(op.getThenBlock()->getTerminator());
2725+
rewriter.mergeBlockBefore(op.getThenBlock(), par->getNextNode());
2726+
rewriter.eraseOp(op);
2727+
return success();
2728+
}
2729+
};
2730+
2731+
/// Inlines the single block of `region` immediately before `op` and replaces
/// `op` with the operands of that block's terminator.
///
/// \param rewriter   Rewriter through which every IR mutation is performed.
/// \param op         Operation being replaced; the block is merged in front
///                   of it.
/// \param region     Region to inline; it must contain exactly one block
///                   (checked by the assertion below).
/// \param blockArgs  Values passed to `mergeBlockBefore` for the block's
///                   arguments; empty by default.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region, ValueRange blockArgs = {}) {
  assert(llvm::hasSingleElement(region) && "expected single-block region");
  Block *block = &region.front();
  Operation *terminator = block->getTerminator();
  // Read the terminator's operands up front; they become the replacement
  // values for `op`.
  ValueRange results = terminator->getOperands();
  rewriter.mergeBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
  // The terminator was merged in front of `op` together with the rest of the
  // block, so it has to be removed explicitly.
  rewriter.eraseOp(terminator);
}
2741+
2742+
/// Folds an `affine.if` whose condition can be decided from constant
/// constraints alone.
///
/// Each constraint of the op's integer set that is an `AffineConstantExpr` is
/// evaluated directly: an equality holds when the constant is 0, an inequality
/// holds when the constant is >= 0.
///  - If any constant constraint does not hold, the condition is known false:
///    the else region is inlined in place of the op when it is non-empty,
///    otherwise the op is erased.
///  - If every constraint is constant and all of them hold, the then region is
///    inlined in place of the op.
///  - Otherwise (at least one non-constant constraint and nothing known
///    false) the pattern does not apply and returns failure.
struct AffineIfSimplification : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp op,
                                PatternRewriter &rewriter) const override {
    auto set = op.getIntegerSet();
    bool knownFalse = false;
    bool hasNonConstantConstraint = false;

    for (auto cst : llvm::enumerate(set.getConstraints())) {
      auto constant = cst.value().dyn_cast<AffineConstantExpr>();
      if (!constant) {
        hasNonConstantConstraint = true;
        continue;
      }

      // Equality constraints mean `expr == 0`, inequalities mean `expr >= 0`.
      const bool holds = set.isEq(cst.index()) ? constant.getValue() == 0
                                               : constant.getValue() >= 0;
      if (!holds) {
        // One unsatisfiable constraint decides the whole conjunction, so the
        // remaining constraints no longer matter.
        knownFalse = true;
        break;
      }
    }

    if (knownFalse) {
      if (!op.elseRegion().empty())
        replaceOpWithRegion(rewriter, op, op.elseRegion());
      else
        rewriter.eraseOp(op);
      return success();
    }

    // TODO can reduce the number of conditions, even if cannot eliminate
    // entirely.
    if (hasNonConstantConstraint)
      return failure();

    // Every constraint is a constant that holds: the condition is always true.
    replaceOpWithRegion(rewriter, op, op.thenRegion());
    return success();
  }
};
@@ -2709,7 +2796,7 @@ void TypeAlignOp::getCanonicalizationPatterns(RewritePatternSet &results,
27092796
AlwaysAllocaScopeHoister<memref::AllocaScopeOp>,
27102797
AlwaysAllocaScopeHoister<scf::ForOp>,
27112798
AlwaysAllocaScopeHoister<AffineForOp>, ConstantRankReduction,
2712-
AffineIfSinking,
2799+
AffineIfSinking, AffineIfSimplification,
27132800
// RankReduction<memref::AllocaOp, scf::ParallelOp>,
27142801
AggressiveAllocaScopeInliner, InductiveVarRemoval>(context);
27152802
}

test/polygeist-opt/ifsink.mlir

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// RUN: polygeist-opt --canonicalize --allow-unregistered-dialect --split-input-file %s | FileCheck %s
2+
3+
// Integer sets used as affine.if conditions in the functions of this module:
// #set0 tests `-d0 == 0`, #set1 tests `d0 == 0`.
#set0 = affine_set<(d0) : (-d0 == 0)>
4+
#set1 = affine_set<(d0) : (d0 == 0)>
5+
module {
// Case 1: the affine.if is the last op of the parallel body, is guarded by
// the negated form (#set0) on the induction variable, and its body only uses
// values defined inside itself. The expected output places the body after
// the affine.parallel.
6+
func.func @bpnn_train_cuda() {
7+
affine.parallel (%arg7) = (0) to (16) {
8+
"test.pre"() : () -> ()
9+
affine.if #set0(%arg7) {
10+
%a = "test.create"() : () -> i32
11+
"test.use"(%a) : (i32) -> ()
12+
}
13+
}
14+
return
15+
}
// Case 2: same shape as case 1 but guarded by the plain form (#set1). The
// expected output again places the body after the affine.parallel.
16+
func.func @bpnn_train_cuda1() {
17+
affine.parallel (%arg7) = (0) to (16) {
18+
"test.pre"() : () -> ()
19+
affine.if #set1(%arg7) {
20+
%a = "test.create"() : () -> i32
21+
"test.use"(%a) : (i32) -> ()
22+
}
23+
}
24+
return
25+
}
// Case 3: the affine.if body uses %a, which is defined in the parallel body
// outside the affine.if. The expected output keeps the affine.if inside the
// affine.parallel.
26+
func.func @bpnn_train_cuda2() {
27+
affine.parallel (%arg7) = (0) to (16) {
28+
%a = "test.create"() : () -> i32
29+
affine.if #set1(%arg7) {
30+
"test.use"(%a) : (i32) -> ()
31+
}
32+
}
33+
return
34+
}
35+
}
36+
37+
// CHECK: func.func @bpnn_train_cuda() {
38+
// CHECK-NEXT: affine.parallel (%arg0) = (0) to (16) {
39+
// CHECK-NEXT: "test.pre"() : () -> ()
40+
// CHECK-NEXT: }
41+
// CHECK-NEXT: %0 = "test.create"() : () -> i32
42+
// CHECK-NEXT: "test.use"(%0) : (i32) -> ()
43+
// CHECK-NEXT: return
44+
// CHECK-NEXT: }
45+
// CHECK: func.func @bpnn_train_cuda1() {
46+
// CHECK-NEXT: affine.parallel (%arg0) = (0) to (16) {
47+
// CHECK-NEXT: "test.pre"() : () -> ()
48+
// CHECK-NEXT: }
49+
// CHECK-NEXT: %0 = "test.create"() : () -> i32
50+
// CHECK-NEXT: "test.use"(%0) : (i32) -> ()
51+
// CHECK-NEXT: return
52+
// CHECK-NEXT: }
53+
// CHECK: func.func @bpnn_train_cuda2() {
54+
// CHECK-NEXT: affine.parallel (%arg0) = (0) to (16) {
55+
// CHECK-NEXT: %0 = "test.create"() : () -> i32
56+
// CHECK-NEXT: affine.if #set(%arg0) {
57+
// CHECK-NEXT: "test.use"(%0) : (i32) -> ()
58+
// CHECK-NEXT: }
59+
// CHECK-NEXT: }
60+
// CHECK-NEXT: return
61+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)