@@ -26,31 +26,61 @@ using namespace polygeist;
26
26
bool isReadOnly (Operation *op);
27
27
bool aboveEq (Value, int64_t );
28
28
29
+ bool isValidSymbolInt (Value value, bool recur = true );
30
+ bool isValidSymbolInt (Operation *defOp, bool recur) {
31
+ Attribute operandCst;
32
+ if (matchPattern (defOp, m_Constant (&operandCst)))
33
+ return true ;
34
+
35
+ if (recur) {
36
+ if (isa<SelectOp, IndexCastOp, AddIOp, MulIOp, DivSIOp, DivUIOp, RemSIOp,
37
+ RemUIOp, SubIOp, CmpIOp, TruncIOp, ExtUIOp, ExtSIOp>(defOp))
38
+ if (llvm::all_of (defOp->getOperands (), [&](Value v) {
39
+ bool b = isValidSymbolInt (v, recur);
40
+ // if (!b)
41
+ // LLVM_DEBUG(llvm::dbgs() << "illegal isValidSymbolInt: "
42
+ // << value << " due to " << v << "\n");
43
+ return b;
44
+ }))
45
+ return true ;
46
+ if (auto ifOp = dyn_cast<scf::IfOp>(defOp)) {
47
+ if (isValidSymbolInt (ifOp.getCondition (), recur)) {
48
+ if (llvm::all_of (
49
+ ifOp.thenBlock ()->without_terminator (),
50
+ [&](Operation &o) { return isValidSymbolInt (&o, recur); }) &&
51
+ llvm::all_of (
52
+ ifOp.elseBlock ()->without_terminator (),
53
+ [&](Operation &o) { return isValidSymbolInt (&o, recur); }))
54
+ return true ;
55
+ }
56
+ }
57
+ if (auto ifOp = dyn_cast<AffineIfOp>(defOp)) {
58
+ if (llvm::all_of (ifOp.getOperands (),
59
+ [&](Value o) { return isValidSymbolInt (o, recur); }))
60
+ if (llvm::all_of (
61
+ ifOp.getThenBlock ()->without_terminator (),
62
+ [&](Operation &o) { return isValidSymbolInt (&o, recur); }) &&
63
+ llvm::all_of (
64
+ ifOp.getElseBlock ()->without_terminator (),
65
+ [&](Operation &o) { return isValidSymbolInt (&o, recur); }))
66
+ return true ;
67
+ }
68
+ }
69
+ return false ;
70
+ }
71
+
29
72
// isValidSymbol, even if not index
30
- bool isValidSymbolInt (Value value, bool recur = true ) {
73
+ bool isValidSymbolInt (Value value, bool recur) {
31
74
// Check that the value is a top level value.
32
75
if (isTopLevelValue (value))
33
76
return true ;
34
77
35
78
if (auto *defOp = value.getDefiningOp ()) {
36
- Attribute operandCst;
37
- if (matchPattern (defOp, m_Constant (&operandCst)))
79
+ if (isValidSymbolInt (defOp, recur))
38
80
return true ;
39
-
40
- if (recur) {
41
- if (isa<SelectOp, IndexCastOp, AddIOp, MulIOp, DivSIOp, DivUIOp, RemSIOp,
42
- RemUIOp, SubIOp, CmpIOp>(defOp))
43
- if (llvm::all_of (defOp->getOperands (), [&](Value v) {
44
- bool b = isValidSymbolInt (v, true );
45
- // if (!b)
46
- // LLVM_DEBUG(llvm::dbgs() << "illegal isValidSymbolInt: "
47
- // << value << " due to " << v << "\n");
48
- return b;
49
- }))
50
- return true ;
51
- }
52
81
return isValidSymbol (value, getAffineScope (defOp));
53
82
}
83
+
54
84
return false ;
55
85
}
56
86
@@ -189,7 +219,22 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
189
219
return nullptr ;
190
220
}
191
221
Operation *front = nullptr ;
192
- for (auto o : op->getOperands ()) {
222
+ SmallVector<Value> ops;
223
+ std::function<void (Operation *)> getAllOps = [&](Operation *todo) {
224
+ for (auto v : todo->getOperands ()) {
225
+ if (llvm::all_of (op->getRegions (), [&](Region &r) {
226
+ return !r.isAncestor (v.getParentRegion ());
227
+ }))
228
+ ops.push_back (v);
229
+ }
230
+ for (auto &r : todo->getRegions ()) {
231
+ for (auto &b : r.getBlocks ())
232
+ for (auto &o2 : b.without_terminator ())
233
+ getAllOps (&o2);
234
+ }
235
+ };
236
+ getAllOps (op);
237
+ for (auto o : ops) {
193
238
Operation *next;
194
239
if (auto *op = o.getDefiningOp ()) {
195
240
if (Value nv = fix (o, index)) {
@@ -234,9 +279,22 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
234
279
for (unsigned i = 0 , e = operands.size (); i < e; ++i) {
235
280
auto t = operands[i];
236
281
auto decast = t;
237
- while (auto idx = decast.getDefiningOp <IndexCastOp>()) {
238
- decast = idx.getIn ();
282
+ while (true ) {
283
+ if (auto idx = decast.getDefiningOp <IndexCastOp>()) {
284
+ decast = idx.getIn ();
285
+ continue ;
286
+ }
287
+ if (auto idx = decast.getDefiningOp <ExtUIOp>()) {
288
+ decast = idx.getIn ();
289
+ continue ;
290
+ }
291
+ if (auto idx = decast.getDefiningOp <ExtSIOp>()) {
292
+ decast = idx.getIn ();
293
+ continue ;
294
+ }
295
+ break ;
239
296
}
297
+
240
298
if (!isValidSymbolInt (t, /* recur*/ false )) {
241
299
t = decast;
242
300
}
@@ -895,6 +953,12 @@ bool isValidIndex(Value val) {
895
953
if (auto cast = val.getDefiningOp <IndexCastOp>())
896
954
return isValidIndex (cast.getOperand ());
897
955
956
+ if (auto cast = val.getDefiningOp <ExtSIOp>())
957
+ return isValidIndex (cast.getOperand ());
958
+
959
+ if (auto cast = val.getDefiningOp <ExtUIOp>())
960
+ return isValidIndex (cast.getOperand ());
961
+
898
962
if (auto bop = val.getDefiningOp <AddIOp>())
899
963
return isValidIndex (bop.getOperand (0 )) && isValidIndex (bop.getOperand (1 ));
900
964
0 commit comments