Skip to content

Commit 4ff4bd0

Browse files
authored
[Warp Specialization] Tweak the scheduling heuristic (#7073)
1 parent 45d9c8b commit 4ff4bd0

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -226,11 +226,17 @@ static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
226226
return std::nullopt;
227227

228228
// Propagate defs of exp.
229-
for (auto expOp : loop.getOps<math::Exp2Op>()) {
230-
auto tensorTy = dyn_cast<RankedTensorType>(expOp.getType());
231-
if (tensorTy && tensorTy.getNumElements() > 256) {
232-
schedule.trySchedule(defaultPartition, expOp);
233-
scheduleDependencies(loop, schedule, defaultPartition, expOp);
229+
for (Operation &op : loop.getOps()) {
230+
if (!isa<math::Exp2Op, ElementwiseInlineAsmOp>(op))
231+
continue;
232+
int elementCount = 0;
233+
for (Type type : op.getResultTypes()) {
234+
if (auto tensorTy = dyn_cast<RankedTensorType>(type))
235+
elementCount += tensorTy.getNumElements();
236+
}
237+
if (elementCount > 256) {
238+
schedule.trySchedule(defaultPartition, &op);
239+
scheduleDependencies(loop, schedule, defaultPartition, &op);
234240
}
235241
}
236242

@@ -242,7 +248,8 @@ static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
242248
while (userPartitions.size() < mmas.size()) {
243249
userPartitions.push_back(schedule.addPartition(userPartitions.size()));
244250
}
245-
for (auto [mmaOp, userPartition] : llvm::zip(mmas, userPartitions)) {
251+
for (auto [mmaOp, userPartition] :
252+
llvm::reverse(llvm::zip(mmas, userPartitions))) {
246253
scheduleUsers(loop, schedule, userPartition, mmaOp);
247254
}
248255

0 commit comments

Comments
 (0)