Skip to content

Commit 896cbdb

Browse files
gxMogball and Jeff Niu authored
[WS] Update aref ops and lower_aref pass (#7479)
* Update arefs ops to align with [`nvws`](https://github.com/triton-lang/triton/tree/aref_auto_ws) branch. * Update lower_aref pass to:     -  threads aref-index through control flow ops (aref-index is used to compute both stage and mbarrier phase)     -  lowers arefs to mbarriers     -  [TODO]: add support for `tma_load` aref lowering, and lowering `nvws::arrive_barrier` will be added in subsequent PRs enabling e2e functionality * Updated lit-tests cc: @masahi, @htyu, @manman-ren, @jeffniu-openai <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [X] I am not making a trivial change, such as fixing a typo in a comment. - [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [X] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [X] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --------- Co-authored-by: Jeff Niu <[email protected]>
1 parent 7f341eb commit 896cbdb

File tree

13 files changed

+768
-305
lines changed

13 files changed

+768
-305
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
2+
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
3+
4+
#include "triton/Dialect/Triton/IR/Dialect.h"
5+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
6+
7+
namespace mlir::triton::nvidia_gpu {
8+
9+
LogicalResult verifyBarrierType(Operation *op,
10+
mlir::triton::gpu::MemDescType barrierType);
11+
12+
}
13+
14+
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
#include "mlir/IR/Builders.h"
2525
#include "mlir/Support/LLVM.h"
2626
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
27-
2827
#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.cpp.inc"
28+
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
2929

3030
using namespace mlir::triton::gpu;
3131

@@ -131,15 +131,6 @@ LogicalResult WarpGroupDotWaitOp::inferReturnTypes(
131131
return mlir::success();
132132
}
133133

134-
static LogicalResult
135-
verifyBarrierType(Operation *op, mlir::triton::gpu::MemDescType barrierType) {
136-
if (!barrierType.getElementType().isInteger(64) ||
137-
barrierType.getShape() != ArrayRef<int64_t>({1}))
138-
return op->emitOpError(
139-
"barrier allocation must be a descriptor of 1xi64 type");
140-
return success();
141-
}
142-
143134
// -- InitBarrierOp --
144135
LogicalResult InitBarrierOp::verify() {
145136
if (failed(verifyBarrierType(*this, getAlloc().getType())))

lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_triton_library(TritonNvidiaGPUTransforms
1111
TensorMemoryAllocation.cpp
1212
TMALowering.cpp
1313
TMAUtilities.cpp
14+
Utility.cpp
1415

1516
DEPENDS
1617
TritonNvidiaGPUTransformsIncGen
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
2+
3+
#define DEBUG_TYPE "ttng-utility"
4+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
5+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
6+
namespace mlir::triton::nvidia_gpu {
7+
8+
using namespace triton;
9+
10+
LogicalResult verifyBarrierType(Operation *op,
11+
mlir::triton::gpu::MemDescType barrierType) {
12+
if (!barrierType.getElementType().isInteger(64) ||
13+
barrierType.getShape() != ArrayRef<int64_t>({1}))
14+
return op->emitOpError(
15+
"barrier allocation must be a descriptor of 1xi64 type");
16+
return success();
17+
}
18+
19+
} // namespace mlir::triton::nvidia_gpu

test/NVWS/invalid.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w
2020
tt.func @aref_get_single(%d : !ttg.memdesc<1x64x16xf16, #shared0, #smem>, %e : !ttg.memdesc<2x16x32xf16, #shared0, #smem>) {
2121
%c0_i32 = arith.constant 0 : i32
2222
// expected-error @below {{Aref buffer is used elsewhere, Aref cannot guarantee async safety}}
23-
%0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #smem>, !ttg.memdesc<2x16x32xf16, #shared0, #smem>], 1>
23+
%0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #smem>, !ttg.memdesc<2x16x32xf16, #shared0, #smem>]>
2424
%1 = ttng.tmem_alloc %d : (!ttg.memdesc<1x64x16xf16, #shared0, #smem>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
2525
tt.return
2626
}
@@ -35,7 +35,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w
3535
%0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #smem>, !ttg.memdesc<1x16x32xf16, #shared0, #smem>]>
3636
%c0_i32 = arith.constant 0 : i32
3737
// expected-error @below {{Aref has different number of arguments than enter}}
38-
%1 = nvws.aref.put.enter %0[%c0_i32, %c0_i32] :
38+
%1 = nvws.aref.put.enter %0[%c0_i32] :
3939
!nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #smem>, !ttg.memdesc<1x16x32xf16, #shared0, #smem>]>
4040
-> !ttg.memdesc<64x16xf16, #shared0, #smem>
4141
tt.return
@@ -51,7 +51,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w
5151
%c0_i32 = arith.constant 0 : i32
5252
%0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #smem>, !ttg.memdesc<1x16x32xf16, #shared0, #smem>]>
5353
// expected-error @below {{Dimensions don't match}}
54-
%1:2 = nvws.aref.put.enter %0[%c0_i32, %c0_i32] :
54+
%1:2 = nvws.aref.put.enter %0[%c0_i32] :
5555
!nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #smem>, !ttg.memdesc<1x16x32xf16, #shared0, #smem>]>
5656
-> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<32x32xf16, #shared0, #smem>
5757
tt.return
@@ -67,7 +67,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w
6767
%c0_i32 = arith.constant 0 : i32
6868
%0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #smem>, !ttg.memdesc<1x16x32xf16, #shared0, #smem>]>
6969
// expected-error @below {{MLIR Types don't match}}
70-
nvws.aref.get.enter %0[%c0_i32, %c0_i32] :
70+
nvws.aref.get.enter %0[%c0_i32] :
7171
!nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #smem>, !ttg.memdesc<1x16x32xf16, #shared0, #smem>]>
7272
-> !ttg.memdesc<64x16xf16, #shared0, #smem>, tensor<16x32xf16>
7373
tt.return

test/NVWS/lower_aref.mlir

Lines changed: 174 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,43 +4,182 @@
44
#smem = #ttg.shared_memory
55
#tmem = #ttng.tensor_memory
66
module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
7-
//CHECK: tt.func @aref_get_put
8-
// CHECK-NEXT: [[ZERO:%.*]] = arith.constant 0 : i32
9-
// CHECK-NEXT: [[ONE:%.*]] = arith.constant 1 : i32
10-
// CHECK-NEXT: [[EMPTY:%.*]] = ttg.local_alloc {aref_empty_mbarriers}
11-
// CHECK-NEXT: [[FULL:%.*]] = ttg.local_alloc {aref_full_mbarriers}
12-
// CHECK-NEXT: scf.for
13-
// CHECK-NEXT: [[EMPTYSLICE:%.*]] = ttg.memdesc_subview [[EMPTY]]
14-
// CHECK-NEXT: ttng.init_barrier [[EMPTYSLICE]], 0
15-
// CHECK-NEXT: [[FULLSLICE:%.*]] = ttg.memdesc_subview [[FULL]]
16-
// CHECK-NEXT: ttng.init_barrier [[FULLSLICE]], 1
17-
// CHECK-NEXT: }
18-
// CHECK-NEXT: [[EMPTYSLICE2:%.*]] = ttg.memdesc_subview [[EMPTY]]
19-
// CHECK-NEXT: ttng.wait_barrier [[EMPTYSLICE2]], [[ONE]]
20-
// CHECK-NEXT: [[A:%.*]] = ttg.memdesc_subview %arg0
21-
// CHECK-NEXT: [[B:%.*]] = ttg.memdesc_subview %arg1
22-
// CHECK-NEXT: "foo"([[A]], [[B]])
23-
// CHECK-NEXT: [[FULLSLICE2:%.*]] = ttg.memdesc_subview [[FULL]]
24-
// CHECK-NEXT: ttng.arrive_barrier [[FULLSLICE2]], 1
25-
// CHECK-NEXT: [[FULLSLICE3:%.*]] = ttg.memdesc_subview [[FULL]]
26-
// CHECK-NEXT: ttng.wait_barrier [[FULLSLICE3]], [[ZERO]]
27-
// CHECK-NEXT: [[AA:%.*]] = ttg.memdesc_subview %arg0
28-
// CHECK-NEXT: [[BB:%.*]] = ttg.memdesc_subview %arg1
29-
// CHECK-NEXT: "bar"([[AA]], [[BB]])
30-
// CHECK-NEXT: [[EMPTYSLICE3:%.*]] = ttg.memdesc_subview [[EMPTY]]
31-
// CHECK-NEXT: ttng.arrive_barrier [[EMPTYSLICE3]],
32-
// CHECK-NEXT: tt.return
33-
// CHECK-NEXT: }
34-
tt.func @aref_get_put(%d : !ttg.memdesc<1x64x16xf16, #shared0, #tmem>, %e : !ttg.memdesc<1x16x32xf16, #shared0, #smem>) {
7+
//CHECK-LABEL: @aref_lowering
8+
tt.func @aref_lowering(%d : !ttg.memdesc<3x64x16xf16, #shared0, #tmem>,
9+
%e : !ttg.memdesc<3x16x32xf16, #shared0, #smem>,
10+
%cond : i1) {
3511
%c0_i32 = arith.constant 0 : i32
3612
%c1_i32 = arith.constant 1 : i32
37-
%0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #tmem>, !ttg.memdesc<1x16x32xf16, #shared0, #smem>]>
38-
%1:2 = nvws.aref.put.enter %0[%c0_i32, %c1_i32] : !nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #tmem>, !ttg.memdesc<1x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>
39-
"foo"(%1#0, %1#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
40-
nvws.aref.put.exit %0[%c0_i32] : !nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #tmem>, !ttg.memdesc<1x16x32xf16, #shared0, #smem>]>
41-
%2:2 = nvws.aref.get.enter %0[%c0_i32, %c0_i32] : !nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #tmem>, !ttg.memdesc<1x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>
42-
"bar"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
43-
nvws.aref.get.exit %0[%c0_i32] : !nvws.aref<[!ttg.memdesc<1x64x16xf16, #shared0, #tmem>, !ttg.memdesc<1x16x32xf16, #shared0, #smem>]>
13+
%lb = arith.constant 0 : i32
14+
// CHECK: [[C3:%.*]] = arith.constant 3 : i32
15+
// CHECK: [[C0:%.*]] = arith.constant 0 : i32
16+
// CHECK: [[C1:%.*]] = arith.constant 1 : i32
17+
%ub = arith.constant 4 : i32
18+
19+
// CHECK: [[EMPTY0:%.*]] = ttg.local_alloc
20+
// CHECK-NEXT: [[FULL0:%.*]] = ttg.local_alloc
21+
// CHECK-NEXT: scf.for
22+
// CHECK-NEXT: [[EMPTYSLICE:%.*]] = ttg.memdesc_subview [[EMPTY0]]
23+
// CHECK-NEXT: ttng.init_barrier [[EMPTYSLICE]], 1
24+
// CHECK-NEXT: [[FULLSLICE:%.*]] = ttg.memdesc_subview [[FULL0]]
25+
// CHECK-NEXT: ttng.init_barrier [[FULLSLICE]], 129
26+
// CHECK-NEXT: }
27+
%aref0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
28+
29+
// CHECK: [[EMPTY1:%.*]] = ttg.local_alloc
30+
// CHECK-NEXT: [[FULL1:%.*]] = ttg.local_alloc
31+
// CHECK-NEXT: scf.for
32+
// CHECK-NEXT: [[EMPTYSLICE:%.*]] = ttg.memdesc_subview [[EMPTY1]]
33+
// CHECK-NEXT: ttng.init_barrier [[EMPTYSLICE]], 256
34+
// CHECK-NEXT: [[FULLSLICE:%.*]] = ttg.memdesc_subview [[FULL1]]
35+
// CHECK-NEXT: ttng.init_barrier [[FULLSLICE]], 128
36+
// CHECK-NEXT: }
37+
%aref1 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
38+
39+
nvws.warp_group
40+
partition0 num_warps(4) {
41+
// CHECK: [[IDX:%.*]]:4 = scf.for [[I:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[C1:%.*]] iter_args([[IDX0:%.*]] = [[C0]], [[IDX1:%.*]] = [[C0]], [[IDX2:%.*]] = [[C0]], [[IDX3:%.*]] = [[C0]])
42+
scf.for %i = %lb to %ub step %c1_i32 : i32{
43+
44+
// CHECK-NEXT: [[EMPTYIDX:%.*]] = arith.remsi [[IDX0]], [[C3]]
45+
// CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_subview [[EMPTY0]][[[EMPTYIDX]]]
46+
// CHECK-NEXT: [[PHASE_DIV:%.*]] = arith.divsi [[IDX0]], [[C3]]
47+
// CHECK-NEXT: [[PHASE_AND:%.*]] = arith.andi [[PHASE_DIV]], [[C1]]
48+
// CHECK-NEXT: [[PHASE_XOR:%.*]] = arith.xori [[PHASE_AND]], [[C1]]
49+
// CHECK-NEXT: ttng.wait_barrier [[EMPTYMBAR]], [[PHASE_XOR]]
50+
%1:2 = nvws.aref.put.enter %aref0[%c0_i32] {aref_tag = "put0"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>
51+
52+
// CHECK-NEXT: [[STAGE:%.*]] = arith.remsi [[IDX0]], [[C3]]
53+
// CHECK-NEXT: [[BUFA:%.*]] = ttg.memdesc_subview %arg0[[[STAGE]],{{.*}},{{.*}}]
54+
// CHECK-NEXT: [[BUFB:%.*]] = ttg.memdesc_subview %arg1[[[STAGE]],{{.*}},{{.*}}]
55+
// CHECK-NEXT: [[FULLIDX:%.*]] = arith.remsi [[IDX2]], [[C3]]
56+
// CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_subview [[FULL0]][[[FULLIDX]]]
57+
// CHECK-NEXT: ttng.barrier_expect [[FULLMBAR]], 0
58+
// CHECK-NEXT: [[IDX0a:%.*]] = arith.addi [[IDX0]], [[C1]]
59+
// CHECK-NEXT: "tma_load"([[BUFA]])
60+
// CHECK-NEXT: "cp_async"([[BUFB]])
61+
"tma_load"(%1#0) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>) -> ()
62+
"cp_async"(%1#1) : (!ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
63+
64+
// CHECK-NEXT: [[FULLIDX:%.*]] = arith.remsi [[IDX2]], [[C3]]
65+
// CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_subview [[FULL0]][[[FULLIDX]]]
66+
// CHECK-NEXT: nvws.async_complete [[FULLMBAR]], async_op = <tma_load>
67+
// CHECK-NEXT: nvws.async_complete [[FULLMBAR]], async_op = <cp_async>
68+
// CHECK-NEXT: [[IDX2a:%.*]] = arith.addi [[IDX2]], [[C1]]
69+
nvws.aref.put.exit %aref0[%c0_i32] [#nvws.async_op<tma_load>, #nvws.async_op<cp_async>] {aref_tag = "put0"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
70+
71+
// CHECK-NEXT: [[IDX13:%.*]]:2 = scf.if
72+
scf.if %cond {
73+
74+
// CHECK: arith.remsi [[IDX1]], [[C3]]
75+
// CHECK: arith.divsi [[IDX1]], [[C3]]
76+
// CHECK-NEXT: arith.andi {{.*}}, [[C1]]
77+
// CHECK-NEXT: arith.xori
78+
// CHECK-NEXT: ttng.wait_barrier
79+
// CHECK: [[IDX1a:%.*]] = arith.addi [[IDX1]], [[C1]]
80+
%2:2 = nvws.aref.put.enter %aref1[%c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>
81+
"tmem_store"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
82+
83+
// CHECK: arith.remsi [[IDX3]], [[C3]]
84+
// CHECK: [[IDX3a:%.*]] = arith.addi [[IDX3]], [[C1]]
85+
nvws.aref.put.exit %aref1[%c0_i32] [#nvws.async_op<none>] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
86+
87+
// CHECK: scf.yield [[IDX1a]], [[IDX3a]]
88+
}
89+
// CHECK-NEXT: } else {
90+
// CHECK-NEXT: scf.yield [[IDX1]], [[IDX3]]
91+
// CHECK-NEXT: }
92+
93+
// CHECK: scf.yield [[IDX0a]], [[IDX13]]#0, [[IDX2a]], [[IDX13]]#1
94+
}
95+
96+
// CHECK: [[IDX1:%.*]]:2 = scf.if
97+
scf.if %cond {
98+
99+
// CHECK: arith.remsi [[IDX]]#0, [[C3]]
100+
// CHECK: arith.divsi [[IDX]]#0, [[C3]]
101+
// CHECK: [[IDX0a:%.*]] = arith.addi [[IDX]]#0, [[C1]]
102+
%1:2 = nvws.aref.put.enter %aref0[%c0_i32] {aref_tag = "put1"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>
103+
"tma_load"(%1#0) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>) -> ()
104+
"cp_async"(%1#1) : (!ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
105+
106+
// CHECK: arith.remsi [[IDX]]#2, [[C3]]
107+
// CHECK: [[IDX2a:%.*]] = arith.addi [[IDX]]#2, [[C1]]
108+
nvws.aref.put.exit %aref0[%c0_i32] [#nvws.async_op<tma_load>, #nvws.async_op<cp_async>] {aref_tag = "put1"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
109+
}
110+
111+
// CHECK: arith.remsi [[IDX]]#1, [[C3]]
112+
// CHECK: arith.divsi [[IDX]]#1, [[C3]]
113+
%1:2 = nvws.aref.put.enter %aref1[%c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>
114+
"tmem_store"(%1#0, %1#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
115+
// CHECK: arith.remsi [[IDX]]#3, [[C3]]
116+
nvws.aref.put.exit %aref1[%c0_i32] [#nvws.async_op<none>] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
117+
nvws.warp_group.return
118+
}
119+
partition1 num_warps(8) {
120+
// CHECK: [[IDX:%.*]]:4 = scf.for [[I:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[C1:%.*]] iter_args([[IDX0:%.*]] = [[C0]], [[IDX1:%.*]] = [[C0]], [[IDX2:%.*]] = [[C0]], [[IDX3:%.*]] = [[C0]])
121+
scf.for %i = %lb to %ub step %c1_i32 : i32{
122+
123+
// CHECK-NEXT: [[FULLIDX:%.*]] = arith.remsi [[IDX0]], [[C3]]
124+
// CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_subview [[FULL0]][[[FULLIDX]]]
125+
// CHECK-NEXT: [[PHASE_DIV:%.*]] = arith.divsi [[IDX0]], [[C3]]
126+
// CHECK-NEXT: [[PHASE_AND:%.*]] = arith.andi [[PHASE_DIV]], [[C1]]
127+
// CHECK-NEXT: ttng.wait_barrier [[FULLMBAR]], [[PHASE_AND]]
128+
%2:2 = nvws.aref.get.enter %aref0[%c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>
129+
130+
// CHECK-NEXT: [[STAGE:%.*]] = arith.remsi [[IDX0]], [[C3]]
131+
// CHECK-NEXT: [[BUFA:%.*]] = ttg.memdesc_subview %arg0[[[STAGE]],{{.*}},{{.*}}]
132+
// CHECK-NEXT: [[BUFB:%.*]] = ttg.memdesc_subview %arg1[[[STAGE]],{{.*}},{{.*}}]
133+
// CHECK-NEXT: arith.addi
134+
// CHECK-NEXT: "tc5mma"([[BUFA]], [[BUFB]])
135+
"tc5mma"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
136+
137+
// CHECK-NEXT: [[EMPTYIDX:%.*]] = arith.remsi [[IDX2]], [[C3]]
138+
// CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_subview [[EMPTY0]][[[EMPTYIDX]]]
139+
// CHECK-NEXT: nvws.async_complete [[EMPTYMBAR]], async_op = <tc5mma>
140+
// CHECK-NEXT: arith.addi
141+
nvws.aref.get.exit %aref0[%c0_i32] [#nvws.async_op<tc5mma>] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
142+
143+
// CHECK: [[IDX13:%.*]]:2 = scf.if
144+
scf.if %cond {
145+
// CHECK: arith.remsi [[IDX1]], [[C3]]
146+
// CHECK: arith.divsi [[IDX1]], [[C3]]
147+
// CHECK-NEXT: arith.andi {{.*}}, [[C1]]
148+
// CHECK-NEXT: ttng.wait_barrier
149+
%3:2 = nvws.aref.get.enter %aref1[%c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>
150+
"tmem_load"(%3#0, %3#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
151+
152+
// CHECK: arith.remsi [[IDX3]], [[C3]]
153+
// CHECK-NEXT: ttg.memdesc_subview
154+
// CHECK-NEXT: nvws.async_complete {{.*}}, async_op = <none>
155+
nvws.aref.get.exit %aref1[%c0_i32] [#nvws.async_op<none>] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
156+
}
157+
// CHECK: } else {
158+
// CHECK-NEXT: scf.yield [[IDX1]], [[IDX3]]
159+
// CHECK-NEXT: }
160+
161+
// CHECK: scf.yield {{.*}}, [[IDX13]]#0, {{.*}}, [[IDX13]]#1
162+
}
163+
scf.if %cond {
164+
// CHECK: arith.remsi [[IDX]]#0, [[C3]]
165+
// CHECK: arith.divsi [[IDX]]#0, [[C3]]
166+
// CHECK-NEXT: arith.andi {{.*}}, [[C1]]
167+
// CHECK-NEXT: ttng.wait_barrier
168+
%2:2 = nvws.aref.get.enter %aref0[%c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>
169+
"tc5mma"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
170+
171+
// CHECK: arith.remsi [[IDX]]#2, [[C3]]
172+
nvws.aref.get.exit %aref0[%c0_i32] [#nvws.async_op<tc5mma>] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
173+
}
174+
// CHECK: } else {
175+
// CHECK-NEXT: scf.yield [[IDX]]#0, [[IDX]]#2
176+
// CHECK-NEXT: }
177+
178+
%2:2 = nvws.aref.get.enter %aref1[%c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>
179+
"tmem_load"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
180+
nvws.aref.get.exit %aref1[%c0_i32] [#nvws.async_op<none>] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
181+
nvws.warp_group.return
182+
}
44183
tt.return
45184
}
46185
}

0 commit comments

Comments (0)