@@ -42,21 +42,24 @@ class SCFLoop {
4242 int64_t getStep () { return step; }
4343 mlir::Value getLowerBound () { return lowerBound; }
4444 mlir::Value getUpperBound () { return upperBound; }
45+ bool isCanonical () { return canonical; }
4546
46- int64_t findStepAndIV (mlir::Value &addr);
47+ // Returns true if successfully finds both step and induction variable.
48+ mlir::LogicalResult findStepAndIV ();
4749 cir::CmpOp findCmpOp ();
4850 mlir::Value findIVInitValue ();
4951 void analysis ();
5052
51- mlir::Value plusConstant (mlir::Value V , mlir::Location loc, int addend);
53+ mlir::Value plusConstant (mlir::Value v , mlir::Location loc, int addend);
5254 void transferToSCFForOp ();
5355
5456private:
5557 cir::ForOp forOp;
5658 cir::CmpOp cmpOp;
57- mlir::Value IVAddr , lowerBound = nullptr , upperBound = nullptr ;
59+ mlir::Value ivAddr , lowerBound = nullptr , upperBound = nullptr ;
5860 mlir::ConversionPatternRewriter *rewriter;
5961 int64_t step = 0 ;
62+ bool canonical = true ;
6063};
6164
6265class SCFWhileLoop {
@@ -86,47 +89,97 @@ class SCFDoLoop {
8689};
8790
8891static int64_t getConstant (cir::ConstantOp op) {
89- auto attr = op-> getAttrs (). front () .getValue ();
90- const auto IntAttr = mlir::dyn_cast <cir::IntAttr>(attr);
91- return IntAttr .getValue ().getSExtValue ();
92+ auto attr = op.getValue ();
93+ const auto intAttr = mlir::cast <cir::IntAttr>(attr);
94+ return intAttr .getValue ().getSExtValue ();
9295}
9396
94- int64_t SCFLoop::findStepAndIV (mlir::Value &addr ) {
97+ mlir::LogicalResult SCFLoop::findStepAndIV () {
9598 auto *stepBlock =
9699 (forOp.maybeGetStep () ? &forOp.maybeGetStep ()->front () : nullptr );
97100 assert (stepBlock && " Can not find step block" );
98101
99- int64_t step = 0 ;
100- mlir::Value IV = nullptr ;
101- // Try to match "IV load addr; ++IV; store IV, addr" to find step.
102- for (mlir::Operation &op : *stepBlock)
103- if (auto loadOp = dyn_cast<cir::LoadOp>(op)) {
104- addr = loadOp.getAddr ();
105- IV = loadOp.getResult ();
106- } else if (auto cop = dyn_cast<cir::ConstantOp>(op)) {
107- if (step)
108- llvm_unreachable (
109- " Not support multiple constant in step calculation yet" );
110- step = getConstant (cop);
111- } else if (auto bop = dyn_cast<cir::BinOp>(op)) {
112- if (bop.getLhs () != IV)
113- llvm_unreachable (" Find BinOp not operate on IV" );
114- if (bop.getKind () != cir::BinOpKind::Add)
115- llvm_unreachable (
116- " Not support BinOp other than Add in step calculation yet" );
117- } else if (auto uop = dyn_cast<cir::UnaryOp>(op)) {
118- if (uop.getInput () != IV)
119- llvm_unreachable (" Find UnaryOp not operate on IV" );
120- if (uop.getKind () == cir::UnaryOpKind::Inc)
121- step = 1 ;
122- else if (uop.getKind () == cir::UnaryOpKind::Dec)
123- llvm_unreachable (" Not support decrement step yet" );
124- } else if (auto storeOp = dyn_cast<cir::StoreOp>(op)) {
125- assert (storeOp.getAddr () == addr && " Can't find IV when lowering ForOp" );
126- }
127- assert (step && " Can't find step when lowering ForOp" );
102+ // Try to match "iv = load addr; ++iv; store iv, addr; yield" to find step.
103+ // We should match the exact pattern, in case there's something unexpected:
104+ // we must rule out cases like `for (int i = 0; i < n; i++, printf("\n"))`.
105+ auto &oplist = stepBlock->getOperations ();
106+
107+ auto iterator = oplist.begin ();
108+
109+ // We might find constants at beginning. Skip them.
110+ // We could have hoisted them outside the for loop in previous passes, but
111+ // it hasn't been done yet.
112+ while (iterator != oplist.end () && isa<ConstantOp>(*iterator))
113+ ++iterator;
114+
115+ if (iterator == oplist.end ())
116+ return mlir::failure ();
117+
118+ auto load = dyn_cast<LoadOp>(*iterator);
119+ if (!load)
120+ return mlir::failure ();
121+
122+ // We assume this is the address of induction variable (IV). The operations
123+ // that come next will check if that's true.
124+ mlir::Value addr = load.getAddr ();
125+ mlir::Value iv = load.getResult ();
126+
127+ // Then we try to match either "++IV" or "IV += n". Same for reversed loops.
128+ if (++iterator == oplist.end ())
129+ return mlir::failure ();
130+
131+ mlir::Operation &arith = *iterator;
132+
133+ if (auto unary = dyn_cast<UnaryOp>(arith)) {
134+ // Not operating on induction variable. Fail.
135+ if (unary.getInput () != iv)
136+ return mlir::failure ();
137+
138+ if (unary.getKind () == UnaryOpKind::Inc)
139+ step = 1 ;
140+ else if (unary.getKind () == UnaryOpKind::Dec)
141+ step = -1 ;
142+ else
143+ return mlir::failure ();
144+ }
145+
146+ if (auto binary = dyn_cast<BinOp>(arith)) {
147+ if (binary.getLhs () != iv)
148+ return mlir::failure ();
149+
150+ mlir::Value value = binary.getRhs ();
151+ if (auto constValue = dyn_cast<ConstantOp>(value.getDefiningOp ());
152+ isa<IntAttr>(constValue.getValue ()))
153+ step = getConstant (constValue);
154+
155+ if (binary.getKind () == BinOpKind::Add)
156+ ; // Nothing to do. Step has been calculated above.
157+ else if (binary.getKind () == BinOpKind::Sub)
158+ step = -step;
159+ else
160+ return mlir::failure ();
161+ }
162+
163+ // Check whether we immediately store this value into the appropriate place.
164+ if (++iterator == oplist.end ())
165+ return mlir::failure ();
166+
167+ auto store = dyn_cast<StoreOp>(*iterator);
168+ if (!store || store.getAddr () != addr ||
169+ store.getValue () != arith.getResult (0 ))
170+ return mlir::failure ();
171+
172+ if (++iterator == oplist.end ())
173+ return mlir::failure ();
128174
129- return step;
175+ // Finally, this should precede a yield with nothing in between.
176+ bool success = isa<YieldOp>(*iterator);
177+
178+ // Remember to update analysis information.
179+ if (success)
180+ ivAddr = addr;
181+
182+ return success ? mlir::success () : mlir::failure ();
130183}
131184
132185static bool isIVLoad (mlir::Operation *op, mlir::Value IVAddr) {
@@ -143,7 +196,7 @@ static bool isIVLoad(mlir::Operation *op, mlir::Value IVAddr) {
143196
144197cir::CmpOp SCFLoop::findCmpOp () {
145198 cmpOp = nullptr ;
146- for (auto *user : IVAddr .getUsers ()) {
199+ for (auto *user : ivAddr .getUsers ()) {
147200 if (user->getParentRegion () != &forOp.getCond ())
148201 continue ;
149202 if (auto loadOp = dyn_cast<cir::LoadOp>(*user)) {
@@ -162,10 +215,10 @@ cir::CmpOp SCFLoop::findCmpOp() {
162215 if (!mlir::isa<cir::IntType>(type))
163216 llvm_unreachable (" Non-integer type IV is not supported" );
164217
165- auto lhsDefOp = cmpOp.getLhs ().getDefiningOp ();
218+ auto * lhsDefOp = cmpOp.getLhs ().getDefiningOp ();
166219 if (!lhsDefOp)
167220 llvm_unreachable (" Can't find IV load" );
168- if (!isIVLoad (lhsDefOp, IVAddr ))
221+ if (!isIVLoad (lhsDefOp, ivAddr ))
169222 llvm_unreachable (" cmpOp LHS is not IV" );
170223
171224 if (cmpOp.getKind () != cir::CmpOpKind::le &&
@@ -187,7 +240,7 @@ mlir::Value SCFLoop::plusConstant(mlir::Value V, mlir::Location loc,
187240// The operations before the loop have been transferred to MLIR.
188241// So we need to go through getRemappedValue to find the value.
189242mlir::Value SCFLoop::findIVInitValue () {
190- auto remapAddr = rewriter->getRemappedValue (IVAddr );
243+ auto remapAddr = rewriter->getRemappedValue (ivAddr );
191244 if (!remapAddr)
192245 return nullptr ;
193246 if (!remapAddr.hasOneUse ())
@@ -199,7 +252,13 @@ mlir::Value SCFLoop::findIVInitValue() {
199252}
200253
201254void SCFLoop::analysis () {
202- step = findStepAndIV (IVAddr);
255+ canonical = mlir::succeeded (findStepAndIV ());
256+ if (!canonical) {
257+ mlir::emitError (forOp.getLoc (),
258+ " cannot handle non-constant step for induction variable" );
259+ return ;
260+ }
261+
203262 cmpOp = findCmpOp ();
204263 auto IVInit = findIVInitValue ();
205264 // The loop end value should be hoisted out of loop by -cir-mlir-scf-prepare.
@@ -237,7 +296,7 @@ void SCFLoop::transferToSCFForOp() {
237296 llvm_unreachable (
238297 " Not support lowering loop with break, continue or if yet" );
239298 // Replace the IV usage to scf loop induction variable.
240- if (isIVLoad (op, IVAddr )) {
299+ if (isIVLoad (op, ivAddr )) {
241300 // Replace CIR IV load with arith.addi scf.IV, 0.
242301 // The replacement makes the SCF IV can be automatically propogated
243302 // by OpAdaptor for individual IV user lowering.
@@ -293,6 +352,10 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
293352 mlir::ConversionPatternRewriter &rewriter) const override {
294353 SCFLoop loop (op, &rewriter);
295354 loop.analysis ();
355+ if (!loop.isCanonical ())
356+ return mlir::emitError (op.getLoc (),
357+ " cannot handle non-canonicalized loop" );
358+
296359 loop.transferToSCFForOp ();
297360 rewriter.eraseOp (op);
298361 return mlir::success ();
0 commit comments