Skip to content

Commit adf3999

Browse files
authored
[WS] make aref ops stage/phase optional (triton-lang#8101)
* make aref ops stage/phase optional
* update lit tests
1 parent 868d242 commit adf3999

File tree

8 files changed

+142
-114
lines changed

8 files changed

+142
-114
lines changed

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,14 @@ namespace {
6464
struct AsyncRef {
6565
auto putView(PartitionBuilder &b, Partition &partition,
6666
StageCluster srcStageCluster) {
67-
auto zero = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
6867
auto enterOp = b.createInto<triton::nvws::ArefPutEnterOp>(
69-
partition, srcStageCluster, viewType, tokenType, aref, zero, zero);
68+
partition, srcStageCluster, aref, TypeRange{viewType}, tokenType);
7069
auto token = enterOp.getToken();
7170

7271
auto exitOp = [this, &partition, srcStageCluster,
7372
token](PartitionBuilder &b) {
74-
auto zero = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
7573
auto exitOp = b.createInto<triton::nvws::ArefPutExitOp>(
76-
partition, srcStageCluster, aref, token, zero,
74+
partition, srcStageCluster, aref, token,
7775
b.getArrayAttr(SmallVector<Attribute>{triton::nvws::AsyncOpAttr::get(
7876
aref.getContext(), triton::nvws::AsyncOp::NONE)}));
7977
};
@@ -82,16 +80,14 @@ struct AsyncRef {
8280

8381
auto getView(PartitionBuilder &b, Partition &partition,
8482
StageCluster srcStageCluster) {
85-
auto zero = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
8683
auto enterOp = b.createInto<triton::nvws::ArefGetEnterOp>(
87-
partition, srcStageCluster, viewType, tokenType, aref, zero, zero);
84+
partition, srcStageCluster, aref, TypeRange{viewType}, tokenType);
8885
auto token = enterOp.getToken();
8986

9087
auto exitOp = [this, &partition, srcStageCluster,
9188
token](PartitionBuilder &b) {
92-
auto zero = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
9389
auto exitOp = b.createInto<triton::nvws::ArefGetExitOp>(
94-
partition, srcStageCluster, aref, token, zero,
90+
partition, srcStageCluster, aref, token,
9591
b.getArrayAttr(SmallVector<Attribute>{triton::nvws::AsyncOpAttr::get(
9692
aref.getContext(), triton::nvws::AsyncOp::NONE)}));
9793
};

test/NVWS/insert_aref.mlir

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,10 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
3030
// CHECK-NEXT: [[AREF2:%.*]] = nvws.aref.create [[AREF_BUF2]]
3131
%1 = scf.for %arg5 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg6 = %0) -> (!ttg.async.token) : i32 {
3232
%2 = arith.muli %arg5, %c64_i32 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i32
33-
// CHECK: [[C_ZERO1:%.*]] = arith.constant {ttg.partition = 2 : i32} 0
34-
// CHECK: [[PUT_BUF1:%.*]], [[TOKEN1:%.*]] = nvws.aref.put.enter [[AREF1]][[[C_ZERO1]], [[C_ZERO1]]] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32}
33+
// CHECK: [[PUT_BUF1:%.*]], [[TOKEN1:%.*]] = nvws.aref.put.enter [[AREF1]] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32}
3534
// CHECK-NEXT: nvws.descriptor_load {{.*}} 16384 [[PUT_BUF1]]
36-
// CHECK: [[C_ZERO2:%.*]] = arith.constant {ttg.partition = 2 : i32} 0
37-
// CHECK: nvws.aref.put.exit [[AREF1]][[[C_ZERO2]]], [[TOKEN1]] [#nvws.async_op<tma_load>] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32}
38-
// CHECK: [[C_ZERO3:%.*]] = arith.constant {ttg.partition = 1 : i32} 0
39-
// CHECK: [[GET_BUF1:%.*]], [[GET_TOKEN1:%.*]] = nvws.aref.get.enter [[AREF1]][[[C_ZERO3]], [[C_ZERO3]]] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 1 : i32}
35+
// CHECK: nvws.aref.put.exit [[AREF1]], [[TOKEN1]] [#nvws.async_op<tma_load>] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32}
36+
// CHECK: [[GET_BUF1:%.*]], [[GET_TOKEN1:%.*]] = nvws.aref.get.enter [[AREF1]] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 1 : i32}
4037
%3 = tt.descriptor_load %arg3[%arg1, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32} : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked1>
4138
// CHECK: [[PUT_BUF2:%.*]], [[TOKEN2:%.*]] = nvws.aref.put.enter [[AREF2]]
4239
// CHECK-NEXT: nvws.descriptor_load {{.*}} 16384 [[PUT_BUF2]]
@@ -51,9 +48,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
5148
%7 = ttg.memdesc_trans %6 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array<i32: 1, 0>, ttg.partition = 1 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared1, #smem>
5249
// CHECK: ttng.tc_gen5_mma [[GET_BUF1]], [[RHS]], {{.*}}, {{.*}}, {{.*}}
5350
%8 = ttng.tc_gen5_mma %5, %7, %result[%arg6], %true, %true {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 1 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
54-
// CHECK: nvws.aref.get.exit [[AREF2]][{{.*}}], [[GET_TOKEN2]]
55-
// CHECK: [[C_ZERO4:%.*]] = arith.constant {ttg.partition = 1 : i32} 0
56-
// CHECK: nvws.aref.get.exit [[AREF1]][[[C_ZERO4]]], [[GET_TOKEN1]] [#nvws.async_op<tc5mma>] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 1 : i32}
51+
// CHECK: nvws.aref.get.exit [[AREF2]], [[GET_TOKEN2]]
52+
// CHECK: nvws.aref.get.exit [[AREF1]], [[GET_TOKEN1]] [#nvws.async_op<tc5mma>] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 1 : i32}
5753
scf.yield %8 : !ttg.async.token
5854
} {tt.num_stages = 2 : i32, tt.scheduled_max_stage = 1 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 0 : i32}
5955
%result_0, %token_1 = ttng.tmem_load %result[%1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>

0 commit comments

Comments
 (0)