
Commit dcc60a8

masahi authored and anmyachev committed
[WS] Fix for aref insert when descriptor_load and local_alloc are in different partitions (#8197)
I was wrongly assuming that the partitions for `descriptor_load` and `local_alloc` are always the same. When they differ, we need to create two arefs: one between the TMA and default partitions, and another between the default and MMA partitions. To do that, we need to treat each of the two ops as a producer operation. [The necessary logic is already there](https://github.com/triton-lang/triton/blob/e28c618576be7b4c78a2083ade058ff20040b266/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp#L489-L505), so we just need to recognize this case.
1 parent 0fdcc6c commit dcc60a8
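To make the case distinction concrete, here is a standalone C++ sketch of the decision described above. `Op`, the integer partition sets, and `fuseLoadAndAlloc` are simplified stand-ins invented for this example; only the same-partition check mirrors the actual fix in `isLoadAndAlloc` below.

```cpp
#include <iostream>
#include <optional>
#include <set>
#include <utility>

// Simplified stand-in for an MLIR op carrying partition assignments.
struct Op {
  const char *name;
  std::set<int> partitionIds;
};

// Mirrors the fixed isLoadAndAlloc: fuse load + alloc into a single producer
// only when both ops live in the same partition(s); otherwise report failure
// so the caller treats each op as its own producer.
std::optional<std::pair<const Op *, const Op *>>
fuseLoadAndAlloc(const Op &load, const Op &alloc) {
  if (load.partitionIds == alloc.partitionIds)
    return std::make_pair(&alloc, &load);
  return std::nullopt;
}

int main() {
  Op load{"tt.descriptor_load", {2}}; // TMA partition
  Op alloc{"ttg.local_alloc", {0}};   // default partition
  if (fuseLoadAndAlloc(load, alloc))
    std::cout << "same partition: one producer, one aref\n";
  else
    std::cout << "different partitions: two producers, two arefs\n";
}
```

Running it prints "different partitions: two producers, two arefs", which is the configuration exercised by the new test below.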

File tree

3 files changed: +51 −3 lines


lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 1 addition & 2 deletions
```diff
@@ -35,8 +35,7 @@ struct AutomaticWarpSpecialization
 void AutomaticWarpSpecialization::runOnOperation() {
   OpPassManager pm;
   pm.addPass(createTritonGPUPartitionScheduling());
-  // TODO: re-enable once the regression is fixed.
-  // pm.addPass(createNVWSInsertAref());
+  pm.addPass(createNVWSInsertAref());
   pm.addPass(createTritonGPULoadMMASpecialization({numStages}));
   pm.addPass(createTritonGPURewritePartitionDependencies());
   // `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic.
```

test/NVWS/insert_aref.mlir

Lines changed: 46 additions & 0 deletions
```diff
@@ -2,6 +2,7 @@

 #blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [128, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
@@ -166,4 +167,49 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
   } {tt.num_stages = 2 : i64, tt.scheduled_max_stage = 1 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 0 : i32}
   tt.return
 }
+
+// FUNC-LABEL: @local_alloc_default_partition
+// CHECK: @local_alloc_default_partition
+tt.func @local_alloc_default_partition(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<tensor<128x128xf16, #shared>>, %arg4: !tt.tensordesc<tensor<128x128xf16, #shared>>) {
+  %true = arith.constant true
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c128_i32 = arith.constant 128 : i32
+  // CHECK: [[AREF_LHS_TRANS:%.*]] = nvws.aref.create {{.*}} : <[!ttg.memdesc<1x128x128xf16, #shared1, #smem, mutable>]>
+  // CHECK: [[AREF_RHS:%.*]] = nvws.aref.create {{.*}} : <[!ttg.memdesc<1x128x128xf16, #shared, #smem, mutable>]>
+  // CHECK: [[AREF_LHS:%.*]] = nvws.aref.create {{.*}} : <[!ttg.memdesc<1x128x128xf16, #shared, #smem, mutable>]>
+  %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+  %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+  %0 = ttng.tmem_store %cst, %result[%token], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+
+  %1 = scf.for %arg5 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg6 = %0) -> (!ttg.async.token) : i32 {
+    %2 = arith.muli %arg5, %c128_i32 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i32
+    // CHECK: [[AREF_LHS_PUT_BUF:%.*]], {{.*}} = nvws.aref.put.enter [[AREF_LHS]] {{.*}}ttg.partition = array<i32: 2>}
+    // CHECK: nvws.descriptor_load {{.*}} 32768 [[AREF_LHS_PUT_BUF]] {{.*}}ttg.partition = array<i32: 2>}
+
+    // CHECK: [[AREF_LHS_GET_BUF:%.*]], {{.*}} = nvws.aref.get.enter [[AREF_LHS]] {{.*}}ttg.partition = array<i32: 0>}
+    // CHECK: [[TMA_RES_REG:%.*]] = ttg.local_load [[AREF_LHS_GET_BUF]] {{.*}}ttg.partition = array<i32: 0>}
+
+    // CHECK: [[AREF_LHS_TRANS_PUT_BUF:%.*]], {{.*}} = nvws.aref.put.enter [[AREF_LHS_TRANS]] {{.*}}ttg.partition = array<i32: 0>}
+    // CHECK: ttg.local_store [[TMA_RES_REG]], [[AREF_LHS_TRANS_PUT_BUF]] {{.*}}ttg.partition = array<i32: 0>}
+
+    // CHECK: [[AREF_LHS_TRANS_GET_BUF:%.*]], {{.*}} = nvws.aref.get.enter [[AREF_LHS_TRANS]] {{.*}}ttg.partition = array<i32: 1>}
+    // CHECK: [[LHS:%.*]] = ttg.memdesc_trans [[AREF_LHS_TRANS_GET_BUF]] {{.*}}ttg.partition = array<i32: 1>}
+
+    %3 = tt.descriptor_load %arg3[%arg1, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array<i32: 2>} : !tt.tensordesc<tensor<128x128xf16, #shared>> -> tensor<128x128xf16, #blocked2>
+    %5 = ttg.local_alloc %3 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array<i32: 0>} : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared1, #smem>
+    %lhs_trans = ttg.memdesc_trans %5 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array<i32: 1, 0>, ttg.partition = array<i32: 1>} : !ttg.memdesc<128x128xf16, #shared1, #smem> -> !ttg.memdesc<128x128xf16, #shared, #smem>
+
+    %4 = tt.descriptor_load %arg4[%arg2, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array<i32: 2>} : !tt.tensordesc<tensor<128x128xf16, #shared>> -> tensor<128x128xf16, #blocked1>
+    %6 = ttg.local_alloc %4 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array<i32: 2>} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+    %7 = ttg.memdesc_trans %6 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array<i32: 1, 0>, ttg.partition = array<i32: 1>} : !ttg.memdesc<128x128xf16, #shared, #smem> -> !ttg.memdesc<128x128xf16, #shared1, #smem>
+
+    // CHECK: ttng.tc_gen5_mma [[LHS]]
+    %8 = ttng.tc_gen5_mma %lhs_trans, %7, %result[%arg6], %true, %true {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array<i32: 1>} : !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf16, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    scf.yield %8 : !ttg.async.token
+  } {tt.num_stages = 2 : i32, tt.scheduled_max_stage = 1 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 0 : i32}
+  %result_0, %token_1 = ttng.tmem_load %result[%1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+  "use"(%result_0) : (tensor<128x128xf32, #blocked>) -> ()
+  tt.return
+}
 }
```

third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp

Lines changed: 4 additions & 1 deletion
```diff
@@ -66,7 +66,10 @@ std::optional<std::pair<AllocOp, LoadOp>> isLoadAndAlloc(Value result) {
   auto alloc = result.getDefiningOp<AllocOp>();
   if (!alloc)
     return std::nullopt;
-  if (auto load = alloc.getSrc().template getDefiningOp<LoadOp>()) {
+  if (auto load = alloc.getSrc().template getDefiningOp<LoadOp>();
+      *getPartitionIds(alloc) == *getPartitionIds(load)) {
+    // if alloc and load are in different partitions, they are treated as two
+    // different producer operations.
     return std::make_pair(alloc, load);
   }
   return std::nullopt;
```
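The patched condition uses C++17's if-with-initializer form: the declaration runs first and stays scoped to the `if`/`else`, and a separate condition then gates the branch. Below is a minimal self-contained illustration of that form; `findDefiningOp` is a made-up helper standing in for `getDefiningOp<LoadOp>()`, not anything from the pass.

```cpp
#include <iostream>
#include <optional>

// Made-up stand-in for getDefiningOp<LoadOp>(): may or may not find a value.
std::optional<int> findDefiningOp(bool found) {
  if (found)
    return 42;
  return std::nullopt;
}

int main() {
  // The initializer runs first and `load` is scoped to the if/else; the
  // separate condition gates the branch, just as the partition-ID comparison
  // does in the patched isLoadAndAlloc.
  if (auto load = findDefiningOp(true); load && *load % 2 == 0)
    std::cout << "condition holds: fuse into one producer\n";
  else
    std::cout << "condition fails: fall through\n";
}
```

When the partition comparison fails in the real code, control falls through to `return std::nullopt;`, so the load and the alloc are each treated as their own producer and get their own aref.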
