Skip to content

Commit c00f747

Browse files
authored
Revert "[Warp Specialization] Don't pipeline loops where latency ops are in the same stage" (#6983)
Reverts triton-lang/triton#6969
1 parent 8ed512f commit c00f747

File tree

2 files changed: +0 additions, −31 deletions

lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -167,32 +167,6 @@ CoarseSchedule getInitialSchedule(scf::ForOp forOp,
   CoarseSchedule schedule;
   if (forOp->hasAttr(kWarpSpecializeAttrName) &&
       succeeded(schedule.deSerialize(forOp))) {
-    // The loop was partitioned from a warp-specialized loop, meaning it can
-    // have a partial view of the original loop stages. Re-schedule the loop
-    // root at the stages of the latency ops to prune unnecessary stages.
-    auto isLatencyOp = [&](Operation &op) {
-      return opLatency.count(&op) ||
-             isa<LocalStoreOp, LocalLoadOp, ttng::TMEMLoadOp, ttng::TMEMStoreOp,
-                 AsyncCopyGlobalToLocalOp, ttng::AsyncTMACopyGlobalToLocalOp,
-                 ttng::AsyncTMAGatherOp, ttng::MMAv5OpInterface,
-                 ttng::WaitBarrierOp, ttng::ArriveBarrierOp>(op);
-    };
-
-    // If there are no latency ops or all latency ops are in the same stage, we
-    // don't need to pipeline the loop. Return a new schedule with everything
-    // assigned to the same stage.
-    DenseSet<int> latencyStages;
-    auto ops = forOp.getBody()->without_terminator();
-    for (Operation &op : llvm::make_filter_range(ops, isLatencyOp))
-      latencyStages.insert(schedule[&op].first);
-    if (latencyStages.size() <= 1) {
-      CoarseSchedule normalized(/*numStages=*/1);
-      auto cluster = normalized.clusters.newAtFront();
-      for (Operation &op : ops)
-        normalized.insert(&op, 0, cluster);
-      return normalized;
-    }
-
     schedule.shrinkToFit();
     return schedule;
   }

test/TritonGPU/automatic-warp-specialization.mlir

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,11 @@ tt.func @matmul_change_desc_in_prologue(
 // CHECK-SAME: num_warps(1)
 // BASE-NOT: tt.make_tensor_descriptor
 // PIPELINE-NOT: tt.experimental_tensormap_create
-// PIPELINE-COUNT-1: tc_gen5_mma
-// PIPELINE-NOT: tc_gen5_mma
 // CHECK-LABEL: partition1
 // CHECK-SAME: num_warps(2)
 // BASE-COUNT-2: tt.make_tensor_descriptor
 // PIPELINE-COUNT-2: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 512 : i32}
 // PIPELINE-COUNT-2: tt.experimental_tensormap_create
-// PIPELINE-NOT: tt.experimental_tensormap_create
-// PIPELINE-COUNT-2: async_tma_copy_global_to_local
-// PIPELINE-NOT: async_tma_copy_global_to_local
 // CHECK-NOT: partition2
 scf.for %k = %c0_i32 to %k_tiles step %c1_i32 iter_args(%acc = %zero, %flag = %true, %a_desc = %a_desc_undef, %b_desc = %b_desc_undef) -> (tensor<128x128xf32, #acc_layout>, i1, !tt.tensordesc<tensor<128x64xf16, #shared>>, !tt.tensordesc<tensor<64x128xf16, #shared>>) : i32 {
 %do_prologue = "prologue_cond"(%k) : (i32) -> i1

0 commit comments

Comments
 (0)