
Commit d90b234

[Warp Specialization] Fix accidentally assigning ops to default partition (#6777)
Ops that have no backward dependency on any other partition are expected to be left in the root partition, because it is easier to let later passes handle rematerialization than to have higher-level passes duplicate them and assign explicit stages. Accidentally assigning partitions to these ops causes later passes to fail. This change also makes those passes hard-fail instead of silently skipping the loop, because the code generated by warp specialization can be invalid if it is not split: the partitions expect to run concurrently.
1 parent 44ecbec commit d90b234
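
To make the intent concrete, the sketch below (not part of the commit) shows the kind of check a helper like hasDefPartition, called in the first hunk of LoadMMASpecialization.cpp below, could perform: walk the backward slice of an op within the loop and see whether it reaches an op that already has a partition. Only WarpSchedule::isScheduled, scf::ForOp, and Block::findAncestorOpInBlock are taken from the diff or upstream MLIR; the function name, signature, and simplifications (direct operands only, no nested-region operands) are assumptions for illustration.

    // Illustrative sketch only -- not the helper from this commit.
    #include "mlir/Dialect/SCF/IR/SCF.h"
    #include "mlir/IR/Operation.h"
    #include "llvm/ADT/SetVector.h"

    using namespace mlir;

    // WarpSchedule is the schedule class used by these passes; its header is
    // part of the TritonGPU warp specialization transforms (include omitted).
    // Returns true if `op` (directly nested in `loop`) transitively depends on
    // another op in the loop body that is already assigned to a partition.
    static bool dependsOnPartitionedOp(scf::ForOp loop, Operation *op,
                                       WarpSchedule &schedule) {
      llvm::SetVector<Operation *> worklist;
      worklist.insert(op);
      for (unsigned i = 0; i < worklist.size(); ++i) {
        Operation *cur = worklist[i];
        if (cur != op && schedule.isScheduled(cur))
          return true; // reaches an op that already has a partition
        for (Value operand : cur->getOperands()) {
          Operation *def = operand.getDefiningOp();
          if (!def)
            continue; // block argument (e.g. loop iter_arg): stop here
          // Normalize to the op directly nested in the loop body, mirroring
          // the findAncestorOpInBlock calls in the pass. The real pass also
          // walks operands captured by nested regions (getNestedOperands).
          if (Operation *anc = loop.getBody()->findAncestorOpInBlock(*def))
            worklist.insert(anc);
        }
      }
      return false; // no def in any partition: keep `op` in the root partition
    }

Ops for which such a check returns false are exactly the ones the commit leaves in the root partition instead of pulling them into the partition of their consumer.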

File tree

6 files changed: +198 -131 lines changed


lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 3 additions & 2 deletions
@@ -229,7 +229,8 @@ static void scheduleDependencies(scf::ForOp loop, WarpSchedule &schedule,

     Operation *defOp =
         loop.getBody()->findAncestorOpInBlock(*dep.getDefiningOp());
-    if (!defOp || !schedule.trySchedule(partition, defOp))
+    if (!defOp || !hasDefPartition(loop, defOp, schedule) ||
+        !schedule.trySchedule(partition, defOp))
       continue;
     llvm::append_range(deps, getNestedOperands(defOp));
   }
@@ -392,7 +393,7 @@ void propagatePartitions(scf::ForOp loop, WarpSchedule &schedule) {
   // For each partition, place users of its outputs in a cluster if it is not
   // already assigned to a partition.
   auto useCallback = [&](OpResult result, OpOperand &use, unsigned distance) {
-    Operation *user = use.getOwner();
+    Operation *user = loop.getBody()->findAncestorOpInBlock(*use.getOwner());
     if (!schedule.isScheduled(user)) {
       // Add the current partition as a def to the cluster.
       opClusters.getOrCreate(user)->defPartitions.insert(&partition);
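
A note on the second hunk: the user of a partition output can be nested inside a region of the loop body (for example in an scf.if), while the schedule tracks only ops directly nested in the loop. The helper below is a hypothetical illustration of the MLIR Block::findAncestorOpInBlock call the hunk relies on; its name and signature are not from the commit.

    #include "mlir/Dialect/SCF/IR/SCF.h"
    #include "mlir/IR/Operation.h"

    using namespace mlir;

    // Hypothetical helper (illustration only): map a use to the op the warp
    // schedule actually tracks, i.e. the ancestor of the user that is directly
    // nested in the loop body. findAncestorOpInBlock walks up the parent chain
    // of the user until it reaches an op whose parent block is the loop body,
    // and returns nullptr if the user is not inside the loop at all.
    static Operation *getTopLevelUserInLoop(scf::ForOp loop, OpOperand &use) {
      return loop.getBody()->findAncestorOpInBlock(*use.getOwner());
    }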

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp

Lines changed: 1 addition & 1 deletion
@@ -269,6 +269,6 @@ void PartitionLoops::runOnOperation() {

   for (scf::ForOp loop : loops) {
     if (failed(partitionLoop(loop)))
-      continue;
+      return signalPassFailure();
   }
 }

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp

Lines changed: 1 addition & 1 deletion
@@ -568,6 +568,6 @@ void RewritePartitionDependencies::runOnOperation() {

   for (scf::ForOp loop : loops) {
     if (failed(rewritePartitionDependencies(loop)))
-      continue;
+      return signalPassFailure();
   }
 }

test/TritonGPU/load-mma-specialization.mlir

Lines changed: 4 additions & 31 deletions
@@ -172,54 +172,27 @@ tt.func @unsupported_load() {
   // CHECK-NEXT: [[DONE_MBAR0:%.*]] = ttg.memdesc_subview [[DONE_MBAR]][%c0_i32]
   // CHECK-NEXT: ttng.init_barrier [[DONE_MBAR0]], 1

-  // CHECK-NEXT: [[A_SHARED:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<128x64xf16,
-  // CHECK-NEXT: [[B_SHARED:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<64x128xf16,
-
-  // CHECK-NEXT: [[OPER_EMPTY_MBAR:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64
-  // CHECK-NEXT: [[OPER_EMPTY_MBAR0:%.*]] = ttg.memdesc_subview [[OPER_EMPTY_MBAR]][%c0_i32]
-  // CHECK-NEXT: init_barrier [[OPER_EMPTY_MBAR0]], 1
-
-  // CHECK-NEXT: [[OPER_READY_MBAR:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1xi64
-  // CHECK-NEXT: [[OPER_READY_MBAR0:%.*]] = ttg.memdesc_subview [[OPER_READY_MBAR]][%c0_i32]
-  // CHECK-NEXT: init_barrier [[OPER_READY_MBAR0]], 1
-
-  // CHECK-NEXT: arrive_barrier [[OPER_EMPTY_MBAR]], 1
-
   // CHECK-NEXT: scf.for
   scf.for %k = %c0_i32 to %k_tiles step %c1_i32 iter_args(%acc = %zero) -> tensor<128x128xf32, #acc_layout> : i32 {
     // CHECK-NEXT: get_ptrs
     %a_ptrs, %b_ptrs = "get_ptrs"(%k) : (i32) -> (tensor<128x64x!tt.ptr<f16>, #oper_layout>, tensor<64x128x!tt.ptr<f16>, #oper_layout>)
-    // CHECK-NEXT: [[A:%.*]] = tt.load
     %a = tt.load %a_ptrs : tensor<128x64x!tt.ptr<f16>, #oper_layout>
-    // CHECK-NEXT: [[B:%.*]] = tt.load
     %b = tt.load %b_ptrs : tensor<64x128x!tt.ptr<f16>, #oper_layout>

-    // CHECK-NEXT: wait_barrier [[OPER_EMPTY_MBAR]]
-    // CHECK-NEXT: local_store [[A]], [[A_SHARED]]
     %a_shared = ttg.local_alloc %a : (tensor<128x64xf16, #oper_layout>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
-    // CHECK-NEXT: local_store [[B]], [[B_SHARED]]
     %b_shared = ttg.local_alloc %b : (tensor<64x128xf16, #oper_layout>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
-    // CHECK-NEXT: arrive_barrier [[OPER_READY_MBAR]], 1

     %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
-    // CHECK-NEXT: [[IS_LAST:%.*]] = arith.cmpi eq, %{{.*}}, %c31_i32
-    // CHECK-NEXT: wait_barrier [[OPER_READY_MBAR]]
-    // CHECK-NEXT: ttng.tc_gen5_mma %{{.*}}, [[ACC]][], %true, %true, [[DONE_MBAR0]][[[IS_LAST]]], [[OPER_EMPTY_MBAR]][%true] {ttg.partition = 1 : i32}
+    // CHECK: [[IS_LAST:%.*]] = arith.cmpi eq, %{{.*}}, %c31_i32
+    // CHECK-NEXT: ttng.tc_gen5_mma %{{.*}}, [[ACC]][], %true, %true, [[DONE_MBAR0]][[[IS_LAST]]] {ttg.partition = 1 : i32}
     %mma_tok = ttng.tc_gen5_mma %a_shared, %b_shared, %c_tmem[%c_tok], %true, %true : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>
     %c, %load_tok = ttng.tmem_load %c_tmem[%mma_tok] : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout>

-    // CHECK-NEXT: [[NEXT_PHASE:%.*]] = arith.xori
-    // CHECK-NEXT: yield [[NEXT_PHASE]]
-
     scf.yield %c : tensor<128x128xf32, #acc_layout>
-  // CHECK-NEXT: ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32]
+  // CHECK: ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32]
   } {tt.warp_specialize}

   // CHECK-NEXT: ttng.wait_barrier [[DONE_MBAR0]], %c0_i32
-  // CHECK-NEXT: ttng.inval_barrier [[OPER_READY_MBAR0]]
-  // CHECK-NEXT: ttg.local_dealloc [[OPER_READY_MBAR]]
-  // CHECK-NEXT: ttng.inval_barrier [[OPER_EMPTY_MBAR0]]
-  // CHECK-NEXT: ttg.local_dealloc [[OPER_EMPTY_MBAR]]
   // CHECK-NEXT: ttng.inval_barrier [[DONE_MBAR0]]
   // CHECK-NEXT: ttg.local_dealloc [[DONE_MBAR]]

@@ -749,7 +722,7 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use_no_multibuf_flag(
     %b_shared = ttg.local_alloc %b : (tensor<64x128xf16, #oper_layout>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
     %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)

-    // CHECK-NEXT: [[DO_EPILOGUE:%.*]] = arith.cmpi eq, [[K:%.*]], %c0_i32
+    // CHECK-NEXT: [[DO_EPILOGUE:%.*]] = arith.cmpi eq, [[K:%.*]], %c0_i32 : i32
     // CHECK-NEXT: [[MMA_TOK:%.*]] = ttng.tc_gen5_mma %{{[0-9]+}}, %{{[0-9]+}}, [[ACC_BUF]][], [[FLAG]], %true, {{.*}}, [[ACC_READY_BUF0]][[[DO_EPILOGUE]]] {ttg.partition = 1 : i32}
     %mma_tok = ttng.tc_gen5_mma %a_shared, %b_shared, %c_tmem[%c_tok], %flag, %true : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>
     %c, %load_tok = ttng.tmem_load %c_tmem[%mma_tok] : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout>

test/TritonGPU/partition-loops.mlir

Lines changed: 20 additions & 11 deletions
@@ -1,20 +1,10 @@
-// RUN: triton-opt %s -allow-unregistered-dialect -tritongpu-partition-loops -verify-diagnostics -canonicalize | FileCheck %s
+// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-partition-loops -verify-diagnostics -canonicalize | FileCheck %s

 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 !ty = tensor<1xi32, #blocked>

 module attributes {"ttg.num-warps" = 4 : i32} {

-tt.func @still_has_ssa_deps(%lb: i32, %ub: i32, %step: i32) {
-  scf.for %i = %lb to %ub step %step : i32 {
-    // expected-warning @below {{non-root partition #0 has direct SSA consumer}}
-    %0 = "op_a"() {ttg.partition = 0} : () -> !ty
-    // expected-note @below {{use at distance 0 in partition #1 here}}
-    "op_b"(%0) {ttg.partition = 1} : (!ty) -> ()
-  } {ttg.partition.stages = [0, 1]}
-  tt.return
-}
-
 // CHECK-LABEL: @no_partitions
 tt.func @no_partitions(%lb: i32, %ub: i32, %step: i32) {
   // CHECK-NEXT: scf.for
@@ -259,3 +249,22 @@ tt.func public @capture_order(%arg0: i32) {
 }

 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+!ty = tensor<1xi32, #blocked>
+
+module attributes {"ttg.num-warps" = 4 : i32} {
+
+tt.func @still_has_ssa_deps(%lb: i32, %ub: i32, %step: i32) {
+  scf.for %i = %lb to %ub step %step : i32 {
+    // expected-warning @below {{non-root partition #0 has direct SSA consumer}}
+    %0 = "op_a"() {ttg.partition = 0} : () -> !ty
+    // expected-note @below {{use at distance 0 in partition #1 here}}
+    "op_b"(%0) {ttg.partition = 1} : (!ty) -> ()
+  } {ttg.partition.stages = [0, 1]}
+  tt.return
+}
+
+}
