@@ -127,39 +127,41 @@ struct TritonIntelTensorDescToBlockPointer
127127 }
128128
// propagateToLoops: copies the types of a loop's init operands onto the
// loop's region iter-args, yielded values and results, and then recurses into
// the users of the loop op. Ops that do not cast to LoopLikeOpInterface are
// left untouched.
// NOTE(review): this span is a diff rendering of the function. Lines tagged
// `N-` are the removed version and lines tagged `N+` are the added version.
// Both contain the same statements; the added version replaces the enclosing
// `if` with an early return.
129129 void propagateToLoops (Operation *op) {
// ---- removed version: whole body nested under the dyn_cast check ----
130- if (auto loopOp = dyn_cast<LoopLikeOpInterface>(op)) {
131- bool updated = false ;
132- for (auto [initArg, rgnInitArg, yieldVal, loopRes] :
133- llvm::zip (loopOp.getInits (), loopOp.getRegionIterArgs (),
134- loopOp.getYieldedValues (), loopOp->getResults ())) {
135- Type initArgType = initArg.getType ();
136- Type rgnInitArgType = rgnInitArg.getType ();
137- assert (rgnInitArgType == loopRes.getType () &&
138- rgnInitArgType == yieldVal.getType () && " Type mismatch" );
139- if (rgnInitArgType != initArgType) {
140- rgnInitArg.setType (initArgType);
141- yieldVal.setType (initArgType);
142- loopRes.setType (initArgType);
143- updated = true ;
144- }
145- }
146- if (!updated)
147- return ;
148-
149- // For while loops we also need to update the "after" region arguments.
150- if (auto loopOp = dyn_cast<scf::WhileOp>(op)) {
151- for (auto [initArg, rgnAfterArg] :
152- llvm::zip (loopOp.getInits (), loopOp.getAfterArguments ())) {
153- Type initArgType = initArg.getType ();
154- if (rgnAfterArg.getType () != initArgType)
155- rgnAfterArg.setType (initArgType);
156- }
// ---- added version: guard clause first, then the same statements ----
// Anything that is not a loop-like op is ignored.
130+ auto loopOp = dyn_cast<LoopLikeOpInterface>(op);
131+ if (!loopOp)
132+ return ;
133+
// `updated` records whether at least one type was rewritten below.
134+ bool updated = false ;
// Walk init operand, region iter-arg, yielded value and loop result together.
135+ for (auto [initArg, rgnInitArg, yieldVal, loopRes] :
136+ llvm::zip (loopOp.getInits (), loopOp.getRegionIterArgs (),
137+ loopOp.getYieldedValues (), loopOp->getResults ())) {
138+ Type initArgType = initArg.getType ();
139+ Type rgnInitArgType = rgnInitArg.getType ();
// Before any rewrite, the iter-arg, result and yielded value must agree.
140+ assert (rgnInitArgType == loopRes.getType () &&
141+ rgnInitArgType == yieldVal.getType () && " Type mismatch" );
// The init operand's type is the one copied onto the other three values.
142+ if (rgnInitArgType != initArgType) {
143+ rgnInitArg.setType (initArgType);
144+ yieldVal.setType (initArgType);
145+ loopRes.setType (initArgType);
146+ updated = true ;
157147 }
148+ }
// Nothing was retyped: return before the while-specific fix-up and before
// recursing into users.
149+ if (!updated)
150+ return ;
158151
159- // Propagate the loop results to their users.
160- for (Operation *user : loopOp->getUsers ())
161- propagateToLoops (user);
152+ // For while loops we also need to update the "after" region arguments.
// NOTE(review): the `loopOp` declared on the next line shadows the
// LoopLikeOpInterface `loopOp` above; inside this `if` the name refers to the
// scf::WhileOp.
153+ if (auto loopOp = dyn_cast<scf::WhileOp>(op)) {
154+ for (auto [initArg, rgnAfterArg] :
155+ llvm::zip (loopOp.getInits (), loopOp.getAfterArguments ())) {
156+ Type initArgType = initArg.getType ();
157+ if (rgnAfterArg.getType () != initArgType)
158+ rgnAfterArg.setType (initArgType);
159+ }
162160 }
161+
162+ // Propagate the loop results to their users.
// Each user goes through the same dyn_cast check at the top of this function.
163+ for (Operation *user : loopOp->getUsers ())
164+ propagateToLoops (user);
163165 }
164166
165167 LogicalResult rewriteMakeTensorDescriptorOp (tt::MakeTensorDescOp op) {
0 commit comments