Commit c79e5d6

[Warp Specialization] Don't pipeline loops where latency ops are in the same stage (RELAND) (#6984)
Reland with a fix to make sure the stages of scheduled loads aren't wiped out.
1 parent 2f72d1d commit c79e5d6
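
In rough terms, the change skips software pipelining when every latency-bound op in the loop body is already assigned to the same pipeline stage. The snippet below is a minimal standalone sketch of that decision rule using plain STL containers and hypothetical stand-in names (Op, allLatencyOpsInOneStage); it is an illustration only, not Triton's actual CoarseSchedule API, which the real hunk below operates on.

#include <set>
#include <vector>

// Hypothetical stand-ins for illustration; the real pass works on MLIR
// Operations and Triton's CoarseSchedule.
struct Op {
  int stage;        // pipeline stage assigned by the coarse schedule
  bool isLatencyOp; // loads, TMA copies, MMAs, barrier waits, ...
};

// True when pipelining would be pointless: either there are no latency ops,
// or every latency op already sits in the same stage.
static bool allLatencyOpsInOneStage(const std::vector<Op> &loopBody) {
  std::set<int> latencyStages;
  for (const Op &op : loopBody)
    if (op.isLatencyOp)
      latencyStages.insert(op.stage);
  return latencyStages.size() <= 1;
}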

2 files changed: 80 additions, 0 deletions

lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp

Lines changed: 27 additions & 0 deletions
@@ -167,6 +167,33 @@ CoarseSchedule getInitialSchedule(scf::ForOp forOp,
   CoarseSchedule schedule;
   if (forOp->hasAttr(kWarpSpecializeAttrName) &&
       succeeded(schedule.deSerialize(forOp))) {
+    // The loop was partitioned from a warp-specialized loop, meaning it can
+    // have a partial view of the original loop stages. Re-schedule the loop
+    // root at the stages of the latency ops to prune unnecessary stages.
+    auto isLatencyOp = [&](Operation &op) {
+      return opLatency.count(&op) ||
+             isa<LoadOp, DescriptorLoadOp, DescriptorGatherOp, LocalStoreOp,
+                 LocalLoadOp, ttng::TMEMLoadOp, ttng::TMEMStoreOp,
+                 AsyncCopyGlobalToLocalOp, ttng::AsyncTMACopyGlobalToLocalOp,
+                 ttng::AsyncTMAGatherOp, ttng::MMAv5OpInterface,
+                 ttng::WaitBarrierOp, ttng::ArriveBarrierOp>(op);
+    };
+
+    // If there are no latency ops or all latency ops are in the same stage, we
+    // don't need to pipeline the loop. Return a new schedule with everything
+    // assigned to the same stage.
+    DenseSet<int> latencyStages;
+    auto ops = forOp.getBody()->without_terminator();
+    for (Operation &op : llvm::make_filter_range(ops, isLatencyOp))
+      latencyStages.insert(schedule[&op].first);
+    if (latencyStages.size() <= 1) {
+      CoarseSchedule normalized(/*numStages=*/1);
+      auto cluster = normalized.clusters.newAtFront();
+      for (Operation &op : ops)
+        normalized.insert(&op, 0, cluster);
+      return normalized;
+    }
+
     schedule.shrinkToFit();
     return schedule;
   }
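
Continuing the hypothetical sketch above, a small usage check contrasts the two cases this hunk distinguishes: when the latency ops all share a stage the loop collapses to a single-stage schedule, otherwise the pipelined schedule is kept.

#include <cassert>
// (Uses the Op / allLatencyOpsInOneStage stand-ins sketched above.)

int main() {
  // All latency ops in stage 0: collapse to a single-stage schedule.
  std::vector<Op> sameStage = {{/*stage=*/0, /*isLatencyOp=*/true},
                               {0, true},
                               {0, false}};
  assert(allLatencyOpsInOneStage(sameStage));

  // Latency ops split across stages 0 and 1: keep the pipelined schedule.
  std::vector<Op> splitStages = {{0, true}, {1, true}, {1, false}};
  assert(!allLatencyOpsInOneStage(splitStages));
  return 0;
}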

test/TritonGPU/automatic-warp-specialization.mlir

Lines changed: 53 additions & 0 deletions
@@ -4,6 +4,7 @@
 #indices_layout = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #acc_layout = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #oper_layout = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#b_layout = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #smem = #ttg.shared_memory
 #acc_tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
@@ -31,11 +32,16 @@ tt.func @matmul_change_desc_in_prologue(
   // CHECK-SAME: num_warps(1)
   // BASE-NOT: tt.make_tensor_descriptor
   // PIPELINE-NOT: tt.experimental_tensormap_create
+  // PIPELINE-COUNT-1: tc_gen5_mma
+  // PIPELINE-NOT: tc_gen5_mma
   // CHECK-LABEL: partition1
   // CHECK-SAME: num_warps(2)
   // BASE-COUNT-2: tt.make_tensor_descriptor
   // PIPELINE-COUNT-2: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 512 : i32}
   // PIPELINE-COUNT-2: tt.experimental_tensormap_create
+  // PIPELINE-NOT: tt.experimental_tensormap_create
+  // PIPELINE-COUNT-2: async_tma_copy_global_to_local
+  // PIPELINE-NOT: async_tma_copy_global_to_local
   // CHECK-NOT: partition2
   scf.for %k = %c0_i32 to %k_tiles step %c1_i32 iter_args(%acc = %zero, %flag = %true, %a_desc = %a_desc_undef, %b_desc = %b_desc_undef) -> (tensor<128x128xf32, #acc_layout>, i1, !tt.tensordesc<tensor<128x64xf16, #shared>>, !tt.tensordesc<tensor<64x128xf16, #shared>>) : i32 {
     %do_prologue = "prologue_cond"(%k) : (i32) -> i1
@@ -108,6 +114,53 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use(
   tt.return
 }
 
+// CHECK-LABEL: @matmul_tma_and_regular_load
+tt.func @matmul_tma_and_regular_load(
+  %a_desc: !tt.tensordesc<tensor<1x64xf16, #shared>>,
+  %b_ptr_init: tensor<64x128x!tt.ptr<f16>, #b_layout> {tt.divisibility = 16 : i32, tt.contiguity = 64 : i32}
+) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %true = arith.constant true
+  %false = arith.constant false
+  %zero = arith.constant dense<0.0> : tensor<128x128xf32, #acc_layout>
+  %k_tiles = arith.constant 32 : i32
+  // CHECK-LABEL: ttg.warp_specialize
+  // CHECK-LABEL: default
+  // CHECK-LABEL: partition0
+  // CHECK-SAME: num_warps(4)
+  // PIPELINE-COUNT-3: async_copy_global_to_local
+  // PIPELINE-NOT: async_copy_global_to_local
+  // CHECK-LABEL: partition1
+  // CHECK-SAME: num_warps(4)
+  // CHECK: [[INDICES:%.*]] = tt.splat %{{.*}} : i32 -> tensor<128xi32,
+  // CHECK: ttng.async_tma_gather %{{.*}}[[[INDICES]],
+  // CHECK-NOT: partition2
+  scf.for %k = %c0_i32 to %k_tiles step %c1_i32 iter_args(%acc = %zero, %flag = %true, %b_ptr = %b_ptr_init) -> (tensor<128x128xf32, #acc_layout>, i1, tensor<64x128x!tt.ptr<f16>, #b_layout>) : i32 {
+    %off_m, %offs_n, %off_k = "get_offsets"(%k) : (i32) -> (i32, tensor<64x128xi32, #b_layout>, i32)
+    %indices = tt.splat %off_m : i32 -> tensor<128xi32, #indices_layout>
+
+    %a = tt.descriptor_gather %a_desc[%indices, %off_k] : (!tt.tensordesc<tensor<1x64xf16, #shared>>, tensor<128xi32, #indices_layout>, i32) -> tensor<128x64xf16, #oper_layout>
+
+    %b_ptrs = tt.addptr %b_ptr, %offs_n {tt.divisibility = dense<16> : tensor<64x128xi32>, tt.contiguity = dense<64> : tensor<64x128xi32>, tt.constancy = dense<1> : tensor<64x128xi32>} : tensor<64x128x!tt.ptr<f16>, #b_layout>, tensor<64x128xi32, #b_layout>
+    %b = tt.load %b_ptrs : tensor<64x128x!tt.ptr<f16>, #b_layout>
+
+    %a_shared = ttg.local_alloc %a : (tensor<128x64xf16, #oper_layout>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
+    %b_shared = ttg.local_alloc %b : (tensor<64x128xf16, #b_layout>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
+    %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %mma_tok = ttng.tc_gen5_mma %a_shared, %b_shared, %c_tmem[%c_tok], %flag, %true : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>
+    %c, %load_tok = ttng.tmem_load %c_tmem[%mma_tok] : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout>
+
+    %do_epilogue = arith.cmpi eq, %k, %c0_i32 : i32
+    %use_acc = arith.select %do_epilogue, %false, %true : i1
+    scf.if %do_epilogue {
+      "acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> ()
+    }
+    scf.yield %c, %use_acc, %b_ptrs : tensor<128x128xf32, #acc_layout>, i1, tensor<64x128x!tt.ptr<f16>, #b_layout>
+  } {tt.warp_specialize, tt.disallow_acc_multi_buffer, tt.num_stages = 2 : i32}
+  tt.return
+}
+
 }
 
 // -----
