Skip to content

Commit 70f9020

Browse files
authored
[NVWS] pass on stage/cluster attributes (#7649)
Adds the stage/cluster attributes to `try_wait/arrive/commit` in the `lower-aref` pass. A lit test will be added once PR triton-lang/triton#7648 is merged, because it depends on a lit test in that PR.
1 parent 0e872d4 commit 70f9020

File tree

6 files changed

+42
-22
lines changed

6 files changed

+42
-22
lines changed

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include "PartitionBuilder.h"
21
#include "mlir/Dialect/UB/IR/UBOps.h"
32
#include "mlir/IR/BuiltinOps.h"
43
#include "mlir/IR/Dominance.h"
@@ -9,6 +8,7 @@
98
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
109
#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h"
1110
#include "triton/Dialect/TritonGPU/Transforms/Partition.h"
11+
#include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h"
1212
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
1313
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
1414
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "PartitionBuilder.h"
1+
#include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h"
22
#include "triton/Dialect/TritonGPU/Transforms/Partition.h"
33
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
44

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
#include "PartitionBuilder.h"
21
#include "mlir/Dialect/SCF/IR/SCF.h"
32
#include "mlir/IR/BuiltinOps.h"
43
#include "mlir/Pass/Pass.h"
54
#include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
65
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
76
#include "triton/Dialect/TritonGPU/Transforms/Partition.h"
7+
#include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h"
88
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
99
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
1010
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

test/NVWS/lower_aref.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,45 +24,45 @@ module attributes {"ttg.num-warps" = 4 : i32} {
2424
partition0 num_warps(4) {
2525
scf.for %arg3 = %arg0 to %arg1 step %arg2 : i32 {
2626
%2 = "op_a"() : () -> tensor<1xi32, #blocked>
27-
%3 = nvws.aref.put.enter %1[%c0_i32, %c0_i32] : <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 2x1>
27+
%3 = nvws.aref.put.enter %1[%c0_i32, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 3 : i32}: <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 2x1>
2828
ttg.local_store %2, %3 : tensor<1xi32, #blocked> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 2x1>
2929
// CHECK: op_a
3030
// CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY]]
31-
// CHECK-NEXT: ttng.wait_barrier [[EMPTYMBAR]]
31+
// CHECK-NEXT: ttng.wait_barrier [[EMPTYMBAR]], {{.*}} {loop.cluster = 1 : i32, loop.stage = 3 : i32}
3232
// CHECK: local_store
3333
// CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL]]
34-
// CHECK-NEXT: ttng.arrive_barrier [[FULLMBAR]], 1
35-
nvws.aref.put.exit %1[%c0_i32] [#nvws.async_op<none>] : <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]>
34+
// CHECK-NEXT: ttng.arrive_barrier [[FULLMBAR]], 1 {loop.cluster = 1 : i32, loop.stage = 3 : i32}
35+
nvws.aref.put.exit %1[%c0_i32] [#nvws.async_op<none>] {loop.cluster = 1 : i32, loop.stage = 3 : i32} : <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]>
3636
}
3737
nvws.warp_group.yield
3838
}
3939
partition1 num_warps(4) {
4040
scf.for %arg3 = %arg0 to %arg1 step %arg2 : i32 {
4141
// CHECK: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL]]
42-
// CHECK-NEXT: ttng.wait_barrier [[FULLMBAR]]
42+
// CHECK-NEXT: ttng.wait_barrier [[FULLMBAR]], {{.*}} {loop.cluster = 2 : i32, loop.stage = 3 : i32}
4343
// CHECK: [[VAL:%.*]] = ttg.local_load
4444
// CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY]]
45-
// CHECK-NEXT: ttng.arrive_barrier [[EMPTYMBAR]], 1
45+
// CHECK-NEXT: ttng.arrive_barrier [[EMPTYMBAR]], 1 {loop.cluster = 2 : i32, loop.stage = 3 : i32}
4646
// CHECK: "op_b"([[VAL]])
47-
%2 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] : <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 2x1>
47+
%2 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {loop.cluster = 2 : i32, loop.stage = 3 : i32}: <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 2x1>
4848
%3 = ttg.local_load %2 : !ttg.memdesc<1xi32, #shared, #smem, mutable, 2x1> -> tensor<1xi32, #blocked>
49-
nvws.aref.get.exit %1[%c0_i32] [#nvws.async_op<none>] : <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]>
49+
nvws.aref.get.exit %1[%c0_i32] [#nvws.async_op<none>] {loop.cluster = 2 : i32, loop.stage = 3 : i32} : <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]>
5050
"op_b"(%3) : (tensor<1xi32, #blocked>) -> ()
5151
}
5252
nvws.warp_group.return
5353
}
5454
partition2 num_warps(4) {
5555
scf.for %arg3 = %arg0 to %arg1 step %arg2 : i32 {
5656
// CHECK: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL]]
57-
// CHECK-NEXT: ttng.wait_barrier [[FULLMBAR]]
57+
// CHECK-NEXT: ttng.wait_barrier [[FULLMBAR]], {{.*}} {loop.cluster = 3 : i32, loop.stage = 4 : i32}
5858
// CHECK: [[VAL:%.*]] = ttg.local_load
5959
// CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY]]
60-
// CHECK-NEXT: ttng.arrive_barrier [[EMPTYMBAR]], 1
60+
// CHECK-NEXT: ttng.arrive_barrier [[EMPTYMBAR]], 1 {loop.cluster = 3 : i32, loop.stage = 4 : i32}
6161
// CHECK: "op_c"([[VAL]])
6262
// CHECK: "op_d"([[VAL]])
63-
%2 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] : <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 2x1>
63+
%2 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {loop.cluster = 3 : i32, loop.stage = 4 : i32}: <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 2x1>
6464
%3 = ttg.local_load %2 : !ttg.memdesc<1xi32, #shared, #smem, mutable, 2x1> -> tensor<1xi32, #blocked>
65-
nvws.aref.get.exit %1[%c0_i32] [#nvws.async_op<none>] : <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]>
65+
nvws.aref.get.exit %1[%c0_i32] [#nvws.async_op<none>] {loop.cluster = 3 : i32, loop.stage = 4 : i32} : <[!ttg.memdesc<2x1xi32, #shared, #smem, mutable>]>
6666
"op_c"(%3) : (tensor<1xi32, #blocked>) -> ()
6767
"op_d"(%3) : (tensor<1xi32, #blocked>) -> ()
6868
}

third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
3838
#include "nvidia/include/Dialect/NVWS/Transforms/Passes.h"
3939
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
40+
#include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h"
4041
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
4142
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
4243
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
@@ -61,6 +62,16 @@ namespace {
6162

6263
// ----------------------------------------------------------------------------
6364

65+
// Tags `op` with the loop stage and loop cluster carried by `stageCluster`.
//
// When `stageCluster` evaluates to true, its `first` member is stored on `op`
// under triton::kLoopStageAttrName and its `second` member under
// triton::kLoopClusterAttrName, each as an i32 integer attribute created
// through `builder`. When `stageCluster` evaluates to false, `op` is left
// untouched.
void assignStageCluster(Operation *op, StageCluster stageCluster,
66+
OpBuilder &builder) {
67+
if (stageCluster) {
68+
op->setAttr(triton::kLoopStageAttrName,
69+
builder.getI32IntegerAttr(stageCluster->first));
70+
op->setAttr(triton::kLoopClusterAttrName,
71+
builder.getI32IntegerAttr(stageCluster->second));
72+
}
73+
}
74+
6475
struct ArefValue {
6576
Value emptyMbars;
6677
Value fullMbars;
@@ -266,7 +277,9 @@ LogicalResult rewritePutEnterOp(ArefCreateOp arefOp, ArefPutEnterOp op,
266277
// get empty barrier at a given stage
267278
Value emptyBarrier = getEmptyBarrier(rewriter, loc, arefVal, op.getStage());
268279

269-
rewriter.create<WaitBarrierOp>(loc, emptyBarrier, op.getPhase());
280+
auto waitOp =
281+
rewriter.create<WaitBarrierOp>(loc, emptyBarrier, op.getPhase());
282+
assignStageCluster(waitOp, getStageCluster(op), rewriter);
270283
auto views = getSubViews(arefVal, op.getStage(), loc, rewriter);
271284
assert(views.size() == op.getResults().size());
272285

@@ -287,7 +300,8 @@ LogicalResult rewriteGetEnterOp(ArefCreateOp arefOp, ArefGetEnterOp op,
287300
rewriter.setInsertionPointAfter(op);
288301

289302
Value fullBarrier = getFullBarrier(rewriter, loc, arefVal, op.getStage());
290-
rewriter.create<WaitBarrierOp>(loc, fullBarrier, op.getPhase());
303+
auto waitOp = rewriter.create<WaitBarrierOp>(loc, fullBarrier, op.getPhase());
304+
assignStageCluster(waitOp, getStageCluster(op), rewriter);
291305
auto views = getSubViews(arefVal, op.getStage(), loc, rewriter);
292306
assert(views.size() == op.getResults().size());
293307

@@ -298,17 +312,19 @@ LogicalResult rewriteGetEnterOp(ArefCreateOp arefOp, ArefGetEnterOp op,
298312
}
299313

300314
LogicalResult insertArriveBarrier(Location loc, ArrayAttr asyncOps,
301-
PatternRewriter &rewriter, Value mbar) {
315+
PatternRewriter &rewriter, Value mbar,
316+
StageCluster stageCluster) {
302317
for (auto asyncOp : asyncOps) {
303318
auto asyncOpEnum = cast<AsyncOpAttr>(asyncOp).getValue();
319+
Operation *arriveOp = {};
304320
switch (asyncOpEnum) {
305321
case AsyncOp::NONE:
306322
case AsyncOp::WGMMA:
307-
rewriter.create<nvidia_gpu::ArriveBarrierOp>(loc, mbar, 1);
323+
arriveOp = rewriter.create<nvidia_gpu::ArriveBarrierOp>(loc, mbar, 1);
308324
break;
309325
case AsyncOp::TC5MMA:
310326
case AsyncOp::TMEMCopy:
311-
rewriter.create<nvidia_gpu::TCGen5CommitOp>(loc, mbar);
327+
arriveOp = rewriter.create<nvidia_gpu::TCGen5CommitOp>(loc, mbar);
312328
break;
313329

314330
case AsyncOp::TMALoad:
@@ -318,6 +334,8 @@ LogicalResult insertArriveBarrier(Location loc, ArrayAttr asyncOps,
318334
default:
319335
llvm_unreachable("unknown async op");
320336
}
337+
if (arriveOp)
338+
assignStageCluster(arriveOp, stageCluster, rewriter);
321339
}
322340

323341
return success();
@@ -328,15 +346,17 @@ LogicalResult rewritePutExitOp(ArefPutExitOp op, PatternRewriter &rewriter,
328346
auto loc = op->getLoc();
329347
rewriter.setInsertionPointAfter(op);
330348
Value fullBarrier = getFullBarrier(rewriter, loc, arefVal, op.getStage());
331-
return insertArriveBarrier(loc, op.getAsyncOps(), rewriter, fullBarrier);
349+
return insertArriveBarrier(loc, op.getAsyncOps(), rewriter, fullBarrier,
350+
getStageCluster(op));
332351
}
333352

334353
LogicalResult rewriteGetExitOp(ArefGetExitOp op, PatternRewriter &rewriter,
335354
ArefValue arefVal) {
336355
auto loc = op->getLoc();
337356
rewriter.setInsertionPointAfter(op);
338357
Value emptyBarrier = getEmptyBarrier(rewriter, loc, arefVal, op.getStage());
339-
return insertArriveBarrier(loc, op.getAsyncOps(), rewriter, emptyBarrier);
358+
return insertArriveBarrier(loc, op.getAsyncOps(), rewriter, emptyBarrier,
359+
getStageCluster(op));
340360
}
341361

342362
LogicalResult rewriteArefDestroyOp(ArefDestroyOp op, PatternRewriter &rewriter,

0 commit comments

Comments
 (0)