@@ -226,11 +226,17 @@ static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
     return std::nullopt;
 
   // Propagate defs of exp.
-  for (auto expOp : loop.getOps<math::Exp2Op>()) {
-    auto tensorTy = dyn_cast<RankedTensorType>(expOp.getType());
-    if (tensorTy && tensorTy.getNumElements() > 256) {
-      schedule.trySchedule(defaultPartition, expOp);
-      scheduleDependencies(loop, schedule, defaultPartition, expOp);
+  for (Operation &op : loop.getOps()) {
+    if (!isa<math::Exp2Op, ElementwiseInlineAsmOp>(op))
+      continue;
+    int elementCount = 0;
+    for (Type type : op.getResultTypes()) {
+      if (auto tensorTy = dyn_cast<RankedTensorType>(type))
+        elementCount += tensorTy.getNumElements();
+    }
+    if (elementCount > 256) {
+      schedule.trySchedule(defaultPartition, &op);
+      scheduleDependencies(loop, schedule, defaultPartition, &op);
     }
   }
 
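The new gate generalizes the old single-op check: instead of inspecting only a math::Exp2Op's single result, it sums getNumElements() across every ranked-tensor result, so multi-result ops such as ElementwiseInlineAsmOp also qualify when their combined output exceeds 256 elements. A standalone sketch of that logic, using a hypothetical TensorShape stand-in for MLIR's RankedTensorType rather than the pass's real types:

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for the shape of one ranked-tensor result.
struct TensorShape {
  std::vector<int64_t> dims;
  int64_t numElements() const {
    return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }
};

// Mirrors the new gate: sum element counts over all tensor results and
// qualify the op only when the combined count exceeds 256.
bool exceedsElementThreshold(const std::vector<TensorShape> &results) {
  int64_t elementCount = 0;
  for (const TensorShape &shape : results)
    elementCount += shape.numElements();
  return elementCount > 256;
}

int main() {
  // E.g. a two-result op with 128x2 and 64x1 tensors: 256 + 64 = 320 > 256.
  assert(exceedsElementThreshold({{{128, 2}}, {{64, 1}}}));
  return 0;
}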
@@ -242,7 +248,8 @@ static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
   while (userPartitions.size() < mmas.size()) {
     userPartitions.push_back(schedule.addPartition(userPartitions.size()));
   }
-  for (auto [mmaOp, userPartition] : llvm::zip(mmas, userPartitions)) {
+  for (auto [mmaOp, userPartition] :
+       llvm::reverse(llvm::zip(mmas, userPartitions))) {
     scheduleUsers(loop, schedule, userPartition, mmaOp);
   }
 
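The second hunk changes only the iteration order: llvm::reverse(llvm::zip(...)) walks the (MMA, partition) pairs from last to first, so users of later MMAs are scheduled into their partitions before users of earlier ones. A minimal sketch of that iteration behavior with stand-in values, assuming LLVM's ADT headers (zip over random-access containers yields bidirectional iterators, which llvm::reverse requires, as the patch itself relies on):

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdio>

int main() {
  llvm::SmallVector<int> mmas = {0, 1, 2};       // stand-ins for MMA ops
  llvm::SmallVector<int> partitions = {4, 5, 6}; // stand-ins for partitions
  // zip pairs mmas[i] with partitions[i]; reverse visits pairs last-to-first.
  for (auto [mma, partition] : llvm::reverse(llvm::zip(mmas, partitions)))
    std::printf("schedule users of mma%d into partition %d\n", mma, partition);
  // Prints mma2, then mma1, then mma0: later MMAs' users are claimed first.
  return 0;
}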