@@ -24,45 +24,45 @@ module attributes {"ttg.num-warps" = 4 : i32} {
24
24
// partition0: the put side of aref %1. Each loop iteration produces a value
// with op_a, enters the put on %1, stores the value into the memdesc returned
// by put.enter, and then exits the put.
// In this hunk the added (+) lines attach { loop.cluster = 1, loop.stage = 3 }
// to put.enter and put.exit, and the updated expected-output lines require the
// same attribute dictionary on the wait_barrier and arrive_barrier ops.
// NOTE(review): presumably the pass under test copies these attributes from
// the aref ops onto the barrier ops it emits - confirm against the pass.
partition0 num_warps (4 ) {
25
25
scf.for %arg3 = %arg0 to %arg1 step %arg2 : i32 {
26
26
// NOTE(review): the quoted op name below shows a leading space (" op_a");
// this looks like an extraction artifact - confirm against the real file.
%2 = " op_a" () : () -> tensor <1 xi32 , #blocked >
27
- %3 = nvws.aref.put.enter %1 [%c0_i32 , %c0_i32 ] : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]> -> !ttg.memdesc <1 xi32 , #shared , #smem , mutable , 2 x1 >
27
+ %3 = nvws.aref.put.enter %1 [%c0_i32 , %c0_i32 ] { loop.cluster = 1 : i32 , loop.stage = 3 : i32 } : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]> -> !ttg.memdesc <1 xi32 , #shared , #smem , mutable , 2 x1 >
28
28
ttg.local_store %2 , %3 : tensor <1 xi32 , #blocked > -> !ttg.memdesc <1 xi32 , #shared , #smem , mutable , 2 x1 >
29
29
// Expected output for the put side: op_a, then an index into EMPTY followed
// by a wait on it, then the store, then an index into FULL followed by an
// arrive with count 1.
// CHECK: op_a
30
30
// CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY]]
31
- // CHECK-NEXT: ttng.wait_barrier [[EMPTYMBAR]]
31
+ // CHECK-NEXT: ttng.wait_barrier [[EMPTYMBAR]], {{.*}} {loop.cluster = 1 : i32, loop.stage = 3 : i32}
32
32
// CHECK: local_store
33
33
// CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL]]
34
- // CHECK-NEXT: ttng.arrive_barrier [[FULLMBAR]], 1
35
- nvws.aref.put.exit %1 [%c0_i32 ] [#nvws.async_op <none >] : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]>
34
+ // CHECK-NEXT: ttng.arrive_barrier [[FULLMBAR]], 1 {loop.cluster = 1 : i32, loop.stage = 3 : i32}
35
+ nvws.aref.put.exit %1 [%c0_i32 ] [#nvws.async_op <none >] { loop.cluster = 1 : i32 , loop.stage = 3 : i32 } : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]>
36
36
}
37
37
nvws.warp_group.yield
38
38
}
39
39
// partition1: a get side of aref %1. Each loop iteration enters the get on
// %1, loads a tensor from the returned memdesc, exits the get, and passes the
// loaded tensor to op_b.
// In this hunk the added (+) lines attach { loop.cluster = 2, loop.stage = 3 }
// to get.enter and get.exit, and the updated expected-output lines require the
// same attribute dictionary on the wait_barrier and arrive_barrier ops.
partition1 num_warps (4 ) {
40
40
scf.for %arg3 = %arg0 to %arg1 step %arg2 : i32 {
41
41
// Expected output for this get side: an index into FULL followed by a wait
// on it, then the load, then an index into EMPTY followed by an arrive with
// count 1, and finally op_b applied to the loaded value.
// CHECK: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL]]
42
- // CHECK-NEXT: ttng.wait_barrier [[FULLMBAR]]
42
+ // CHECK-NEXT: ttng.wait_barrier [[FULLMBAR]], {{.*}} {loop.cluster = 2 : i32, loop.stage = 3 : i32}
43
43
// CHECK: [[VAL:%.*]] = ttg.local_load
44
44
// CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY]]
45
- // CHECK-NEXT: ttng.arrive_barrier [[EMPTYMBAR]], 1
45
+ // CHECK-NEXT: ttng.arrive_barrier [[EMPTYMBAR]], 1 {loop.cluster = 2 : i32, loop.stage = 3 : i32}
46
46
// CHECK: "op_b"([[VAL]])
47
- %2 = nvws.aref.get.enter %1 [%c0_i32 , %c0_i32 ] : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]> -> !ttg.memdesc <1 xi32 , #shared , #smem , mutable , 2 x1 >
47
+ %2 = nvws.aref.get.enter %1 [%c0_i32 , %c0_i32 ] { loop.cluster = 2 : i32 , loop.stage = 3 : i32 } : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]> -> !ttg.memdesc <1 xi32 , #shared , #smem , mutable , 2 x1 >
48
48
%3 = ttg.local_load %2 : !ttg.memdesc <1 xi32 , #shared , #smem , mutable , 2 x1 > -> tensor <1 xi32 , #blocked >
49
- nvws.aref.get.exit %1 [%c0_i32 ] [#nvws.async_op <none >] : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]>
49
+ nvws.aref.get.exit %1 [%c0_i32 ] [#nvws.async_op <none >] { loop.cluster = 2 : i32 , loop.stage = 3 : i32 } : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]>
50
50
// NOTE(review): the quoted op name below shows a leading space (" op_b"),
// while the expected-output pattern above has none; this looks like an
// extraction artifact - confirm against the real file.
" op_b" (%3 ) : (tensor <1 xi32 , #blocked >) -> ()
51
51
}
52
52
nvws.warp_group.return
53
53
}
54
54
partition2 num_warps (4 ) {
55
55
// Loop body of partition2: another get side of aref %1. Each iteration enters
// the get on %1, loads a tensor from the returned memdesc, exits the get, and
// passes the same loaded tensor to both op_c and op_d.
// In this hunk the added (+) lines attach { loop.cluster = 3, loop.stage = 4 }
// to get.enter and get.exit, and the updated expected-output lines require the
// same attribute dictionary on the wait_barrier and arrive_barrier ops.
scf.for %arg3 = %arg0 to %arg1 step %arg2 : i32 {
56
56
// CHECK: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL]]
57
- // CHECK-NEXT: ttng.wait_barrier [[FULLMBAR]]
57
+ // CHECK-NEXT: ttng.wait_barrier [[FULLMBAR]], {{.*}} {loop.cluster = 3 : i32, loop.stage = 4 : i32}
58
58
// CHECK: [[VAL:%.*]] = ttg.local_load
59
59
// CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY]]
60
- // CHECK-NEXT: ttng.arrive_barrier [[EMPTYMBAR]], 1
60
+ // CHECK-NEXT: ttng.arrive_barrier [[EMPTYMBAR]], 1 {loop.cluster = 3 : i32, loop.stage = 4 : i32}
61
61
// CHECK: "op_c"([[VAL]])
62
62
// CHECK: "op_d"([[VAL]])
63
- %2 = nvws.aref.get.enter %1 [%c0_i32 , %c0_i32 ] : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]> -> !ttg.memdesc <1 xi32 , #shared , #smem , mutable , 2 x1 >
63
+ %2 = nvws.aref.get.enter %1 [%c0_i32 , %c0_i32 ] { loop.cluster = 3 : i32 , loop.stage = 4 : i32 } : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]> -> !ttg.memdesc <1 xi32 , #shared , #smem , mutable , 2 x1 >
64
64
%3 = ttg.local_load %2 : !ttg.memdesc <1 xi32 , #shared , #smem , mutable , 2 x1 > -> tensor <1 xi32 , #blocked >
65
- nvws.aref.get.exit %1 [%c0_i32 ] [#nvws.async_op <none >] : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]>
65
+ nvws.aref.get.exit %1 [%c0_i32 ] [#nvws.async_op <none >] { loop.cluster = 3 : i32 , loop.stage = 4 : i32 } : <[!ttg.memdesc <2 x1 xi32 , #shared , #smem , mutable >]>
66
66
// NOTE(review): the quoted op names below show a leading space (" op_c",
// " op_d"), while the expected-output patterns above have none; this looks
// like an extraction artifact - confirm against the real file.
" op_c" (%3 ) : (tensor <1 xi32 , #blocked >) -> ()
67
67
" op_d" (%3 ) : (tensor <1 xi32 , #blocked >) -> ()
68
68
}
0 commit comments