Skip to content

Commit be9ee42

Browse files
authored
Affraise (#261)
* Affraise improve * Handle ext on affraise
1 parent 2dded82 commit be9ee42

File tree

2 files changed

+112
-20
lines changed

2 files changed

+112
-20
lines changed

lib/polygeist/Passes/AffineCFG.cpp

Lines changed: 83 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,31 +26,61 @@ using namespace polygeist;
2626
bool isReadOnly(Operation *op);
2727
bool aboveEq(Value, int64_t);
2828

29+
bool isValidSymbolInt(Value value, bool recur = true);
30+
bool isValidSymbolInt(Operation *defOp, bool recur) {
31+
Attribute operandCst;
32+
if (matchPattern(defOp, m_Constant(&operandCst)))
33+
return true;
34+
35+
if (recur) {
36+
if (isa<SelectOp, IndexCastOp, AddIOp, MulIOp, DivSIOp, DivUIOp, RemSIOp,
37+
RemUIOp, SubIOp, CmpIOp, TruncIOp, ExtUIOp, ExtSIOp>(defOp))
38+
if (llvm::all_of(defOp->getOperands(), [&](Value v) {
39+
bool b = isValidSymbolInt(v, recur);
40+
// if (!b)
41+
// LLVM_DEBUG(llvm::dbgs() << "illegal isValidSymbolInt: "
42+
//<< value << " due to " << v << "\n");
43+
return b;
44+
}))
45+
return true;
46+
if (auto ifOp = dyn_cast<scf::IfOp>(defOp)) {
47+
if (isValidSymbolInt(ifOp.getCondition(), recur)) {
48+
if (llvm::all_of(
49+
ifOp.thenBlock()->without_terminator(),
50+
[&](Operation &o) { return isValidSymbolInt(&o, recur); }) &&
51+
llvm::all_of(
52+
ifOp.elseBlock()->without_terminator(),
53+
[&](Operation &o) { return isValidSymbolInt(&o, recur); }))
54+
return true;
55+
}
56+
}
57+
if (auto ifOp = dyn_cast<AffineIfOp>(defOp)) {
58+
if (llvm::all_of(ifOp.getOperands(),
59+
[&](Value o) { return isValidSymbolInt(o, recur); }))
60+
if (llvm::all_of(
61+
ifOp.getThenBlock()->without_terminator(),
62+
[&](Operation &o) { return isValidSymbolInt(&o, recur); }) &&
63+
llvm::all_of(
64+
ifOp.getElseBlock()->without_terminator(),
65+
[&](Operation &o) { return isValidSymbolInt(&o, recur); }))
66+
return true;
67+
}
68+
}
69+
return false;
70+
}
71+
2972
// isValidSymbol, even if not index
30-
bool isValidSymbolInt(Value value, bool recur = true) {
73+
bool isValidSymbolInt(Value value, bool recur) {
3174
// Check that the value is a top level value.
3275
if (isTopLevelValue(value))
3376
return true;
3477

3578
if (auto *defOp = value.getDefiningOp()) {
36-
Attribute operandCst;
37-
if (matchPattern(defOp, m_Constant(&operandCst)))
79+
if (isValidSymbolInt(defOp, recur))
3880
return true;
39-
40-
if (recur) {
41-
if (isa<SelectOp, IndexCastOp, AddIOp, MulIOp, DivSIOp, DivUIOp, RemSIOp,
42-
RemUIOp, SubIOp, CmpIOp>(defOp))
43-
if (llvm::all_of(defOp->getOperands(), [&](Value v) {
44-
bool b = isValidSymbolInt(v, true);
45-
// if (!b)
46-
// LLVM_DEBUG(llvm::dbgs() << "illegal isValidSymbolInt: "
47-
//<< value << " due to " << v << "\n");
48-
return b;
49-
}))
50-
return true;
51-
}
5281
return isValidSymbol(value, getAffineScope(defOp));
5382
}
83+
5484
return false;
5585
}
5686

@@ -189,7 +219,22 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
189219
return nullptr;
190220
}
191221
Operation *front = nullptr;
192-
for (auto o : op->getOperands()) {
222+
SmallVector<Value> ops;
223+
std::function<void(Operation *)> getAllOps = [&](Operation *todo) {
224+
for (auto v : todo->getOperands()) {
225+
if (llvm::all_of(op->getRegions(), [&](Region &r) {
226+
return !r.isAncestor(v.getParentRegion());
227+
}))
228+
ops.push_back(v);
229+
}
230+
for (auto &r : todo->getRegions()) {
231+
for (auto &b : r.getBlocks())
232+
for (auto &o2 : b.without_terminator())
233+
getAllOps(&o2);
234+
}
235+
};
236+
getAllOps(op);
237+
for (auto o : ops) {
193238
Operation *next;
194239
if (auto *op = o.getDefiningOp()) {
195240
if (Value nv = fix(o, index)) {
@@ -234,9 +279,22 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
234279
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
235280
auto t = operands[i];
236281
auto decast = t;
237-
while (auto idx = decast.getDefiningOp<IndexCastOp>()) {
238-
decast = idx.getIn();
282+
while (true) {
283+
if (auto idx = decast.getDefiningOp<IndexCastOp>()) {
284+
decast = idx.getIn();
285+
continue;
286+
}
287+
if (auto idx = decast.getDefiningOp<ExtUIOp>()) {
288+
decast = idx.getIn();
289+
continue;
290+
}
291+
if (auto idx = decast.getDefiningOp<ExtSIOp>()) {
292+
decast = idx.getIn();
293+
continue;
294+
}
295+
break;
239296
}
297+
240298
if (!isValidSymbolInt(t, /*recur*/ false)) {
241299
t = decast;
242300
}
@@ -895,6 +953,12 @@ bool isValidIndex(Value val) {
895953
if (auto cast = val.getDefiningOp<IndexCastOp>())
896954
return isValidIndex(cast.getOperand());
897955

956+
if (auto cast = val.getDefiningOp<ExtSIOp>())
957+
return isValidIndex(cast.getOperand());
958+
959+
if (auto cast = val.getDefiningOp<ExtUIOp>())
960+
return isValidIndex(cast.getOperand());
961+
898962
if (auto bop = val.getDefiningOp<AddIOp>())
899963
return isValidIndex(bop.getOperand(0)) && isValidIndex(bop.getOperand(1));
900964

test/polygeist-opt/affinecfg.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: polygeist-opt --affine-cfg --split-input-file %s | FileCheck %s
1+
// RUN: polygeist-opt --affine-cfg --split-input-file --allow-unregistered-dialect %s | FileCheck %s
22
module {
33
func.func @_Z7runTestiPPc(%arg0: index, %arg2: memref<?xi32>) {
44
%c0_i32 = arith.constant 0 : i32
@@ -135,3 +135,31 @@ func.func @_Z7runTestiPPc(%arg0: i32, %39: memref<?xi32>, %arg1: !llvm.ptr<i8>)
135135
// CHECK-NEXT: }
136136
// CHECK-NEXT: return
137137
// CHECK-NEXT: }
138+
139+
// -----
140+
141+
module {
142+
func.func @c(%71: memref<?xf32>, %39: i64) {
143+
affine.parallel (%arg2, %arg3) = (0, 0) to (42, 512) {
144+
%262 = arith.index_cast %arg2 : index to i32
145+
%a264 = arith.extsi %262 : i32 to i64
146+
%268 = arith.cmpi slt, %a264, %39 : i64
147+
scf.if %268 {
148+
"test.something"() : () -> ()
149+
}
150+
}
151+
return
152+
}
153+
}
154+
155+
// CHECK: #set = affine_set<(d0)[s0] : (-d0 + s0 - 1 >= 0)>
156+
// CHECK: func.func @c(%arg0: memref<?xf32>, %arg1: i64) {
157+
// CHECK-NEXT: %0 = arith.index_cast %arg1 : i64 to index
158+
// CHECK-NEXT: affine.parallel (%arg2, %arg3) = (0, 0) to (42, 512) {
159+
// CHECK-NEXT: affine.if #set(%arg2)[%0] {
160+
// CHECK-NEXT: "test.something"() : () -> ()
161+
// CHECK-NEXT: }
162+
// CHECK-NEXT: }
163+
// CHECK-NEXT: return
164+
// CHECK-NEXT: }
165+

0 commit comments

Comments
 (0)