Commit 98147ef

[AMD][BACKEND] Support padded shared layouts with direct-to-lds loads (triton-lang#8085)
Adds support for padded shared layouts to `AMDCoalesceAsyncCopy`, `UpdateAsyncWaitCount`, and the lowering of `ttg.async_copy_global_to_local`/`amdgpu.buffer_load_to_local`.
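
For context, a minimal sketch of the kind of IR this enables, distilled from the new test cases below. `%ptrs` and `%dst` are placeholder names, and the reading of the pad spec (`[4:+4]` inserts 4 padding elements after every 4 elements) is our assumption:

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.padded_shared<[4:+4] {order = [1, 0], shape = [32, 64]}>
#smem = #ttg.shared_memory
// Asynchronously copy a 32x64 f32 tile from global memory into the padded LDS buffer
%token = ttg.async_copy_global_to_local %ptrs, %dst : tensor<32x64x!tt.ptr<f32>, #blocked> -> <32x64xf32, #shared, #smem, mutable>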
1 parent a6d11f7 commit 98147ef

7 files changed: +163 −53 lines changed

test/Conversion/amd/async_ops_to_llvm.mlir

Lines changed: 20 additions & 0 deletions
@@ -21,6 +21,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.padded_shared<[4:+4] {order = [1, 0], shape = [32, 64]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+// CHECK-LABEL: async_copy_padded
+tt.func public @async_copy_padded(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                  %arg1: i32 {tt.divisibility = 16 : i32},
+                  %arg2: !ttg.memdesc<32x64xf32, #shared, #smem, mutable>) {
+  // We need the splat to allow the AxisAnalysis to work during lowering
+  %1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x64x!tt.ptr<f32>, #blocked>
+  // Each thread needs to load 8 elements and we load 1 element (32 bit) per global.load.lds
+  // CHECK-COUNT-8: rocdl.global.load.lds
+  // CHECK-NOT: rocdl.global.load.lds
+  %2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr<f32>, #blocked> -> <32x64xf32, #shared, #smem, mutable>
+  tt.return
+}
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
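
For the CHECK-COUNT-8 above: the tile holds 32 × 64 = 2048 f32 elements, the module runs 4 warps × 64 lanes = 256 threads, so each thread covers 2048 / 256 = 8 elements, and at 1 element (32 bit) per global.load.lds that is 8 instructions per thread.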

test/Conversion/amd/async_ops_to_llvm_invalid.mlir

Lines changed: 23 additions & 0 deletions
@@ -32,3 +32,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
   tt.return
 }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+// Padding interval of 1 forces vec==1 which we cannot lower because it's less than 32bits per lane
+#shared = #ttg.padded_shared<[1:+2] {order = [1, 0], shape = [32, 64]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+tt.func public @async_copy_padded_invalid_vec(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                  %arg1: i32 {tt.divisibility = 16 : i32},
+                  %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
+  // We need the index calculation so AxisAnalysis sees that we can vectorize the load
+  %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+  %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+  %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
+  %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
+  %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
+
+  // expected-error@+1 {{failed to legalize operation 'ttg.async_copy_global_to_local' that was explicitly marked illegal}}
+  %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
+  tt.return
+}
+}
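
The invalid case pins down the constraint: each lane must write at least 32 bits per direct-to-LDS load, and a padding interval of 1 element clips the vector size to 1 × f16 = 16 bits. As a sketch only (not part of this commit, and assuming the vector size is clipped to the padding interval as the comment above suggests), an interval of 2 would keep a 2 × f16 = 32 bit write per lane and should remain lowerable:

#shared = #ttg.padded_shared<[2:+2] {order = [1, 0], shape = [32, 64]}>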

test/TritonGPU/amd/amd-coalesce-async-copy.mlir

Lines changed: 49 additions & 24 deletions
@@ -5,11 +5,30 @@
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // sizePerThread = [1] because we have no information about contiguity of src pointers
-// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+// CHECK: #[[$NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+// CHECK-LABEL: async_copy_1d
 tt.func @async_copy_1d(%input: tensor<1024x!tt.ptr<f32>, #blocked>,
                %view: !ttg.memdesc<1024xf32, #shared, #smem, mutable>) {
-  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<1024x!tt.ptr<f32>, #[[NEW_BLOCKED]]>
-  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<1024x!tt.ptr<f32>, #[[NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<1024x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<1024x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
+  %token = ttg.async_copy_global_to_local %input, %view: tensor<1024x!tt.ptr<f32>, #blocked> -> <1024xf32, #shared, #smem, mutable>
+  tt.return
+}
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#shared = #ttg.padded_shared<[4:+4] {order = [0], shape = [1024]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+// sizePerThread = [1] because we have no information about contiguity of src pointers
+// CHECK: #[[$NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+// CHECK-LABEL: async_copy_with_padding
+tt.func @async_copy_with_padding(%input: tensor<1024x!tt.ptr<f32>, #blocked>,
+               %view: !ttg.memdesc<1024xf32, #shared, #smem, mutable>) {
+  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<1024x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<1024x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
   %token = ttg.async_copy_global_to_local %input, %view: tensor<1024x!tt.ptr<f32>, #blocked> -> <1024xf32, #shared, #smem, mutable>
   tt.return
 }
@@ -22,11 +41,12 @@ tt.func @async_copy_1d(%input: tensor<1024x!tt.ptr<f32>, #blocked>,
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // sizePerThread = [1, 1] because we have no information about contiguity of src pointers
-// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+// CHECK: #[[$NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+// CHECK-LABEL: async_copy_2d
 tt.func @async_copy_2d(%input: tensor<64x64x!tt.ptr<f32>, #blocked>,
                %view: !ttg.memdesc<64x64xf32, #shared, #smem, mutable>) {
-  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64x!tt.ptr<f32>, #[[NEW_BLOCKED]]>
-  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x64x!tt.ptr<f32>, #[[NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x64x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
   %token = ttg.async_copy_global_to_local %input, %view: tensor<64x64x!tt.ptr<f32>, #blocked> -> <64x64xf32, #shared, #smem, mutable>
   tt.return
 }
@@ -39,11 +59,12 @@ tt.func @async_copy_2d(%input: tensor<64x64x!tt.ptr<f32>, #blocked>,
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // sizePerThread = [1, 1, 1] because we have no information about contiguity of src pointers
-// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}>
+// CHECK: #[[$NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}>
+// CHECK-LABEL: async_copy_3d
 tt.func @async_copy_3d(%input: tensor<1024x1024x1024x!tt.ptr<f32>, #blocked>,
                %view: !ttg.memdesc<1024x1024x1024xf32, #shared, #smem, mutable>) {
-  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<1024x1024x1024x!tt.ptr<f32>, #[[NEW_BLOCKED]]>
-  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<1024x1024x1024x!tt.ptr<f32>, #[[NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<1024x1024x1024x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<1024x1024x1024x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
   %token = ttg.async_copy_global_to_local %input, %view: tensor<1024x1024x1024x!tt.ptr<f32>, #blocked> -> <1024x1024x1024xf32, #shared, #smem, mutable>
   tt.return
 }
@@ -55,15 +76,16 @@ tt.func @async_copy_3d(%input: tensor<1024x1024x1024x!tt.ptr<f32>, #blocked>,
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
-// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+// CHECK: #[[$NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+// CHECK-LABEL: async_copy_with_mask_and_other
 tt.func @async_copy_with_mask_and_other(%input: tensor<64x64x!tt.ptr<f32>, #blocked>,
                %view: !ttg.memdesc<64x64xf32, #shared, #smem, mutable>,
                %mask: tensor<64x64xi1, #blocked>,
                %other: tensor<64x64xf32, #blocked>) {
-  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64x!tt.ptr<f32>, #[[NEW_BLOCKED]]>
-  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64xi1, #[[NEW_BLOCKED]]>
-  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64xf32, #[[NEW_BLOCKED]]>
-  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x64x!tt.ptr<f32>, #[[NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64xi1, #[[$NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64xf32, #[[$NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x64x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
   %token = ttg.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x64x!tt.ptr<f32>, #blocked> -> <64x64xf32, #shared, #smem, mutable>
   tt.return
 }
@@ -76,7 +98,8 @@ tt.func @async_copy_with_mask_and_other(%input: tensor<64x64x!tt.ptr<f32>, #bloc
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
 // Clip to vector size 2 (32bit) because we do not support 64 bit loads to lds
-// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK: #[[$NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK-LABEL: async_copy_vector_size_2
 tt.func public @async_copy_vector_size_2(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                %arg1: i32 {tt.divisibility = 16 : i32},
                %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
@@ -87,8 +110,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
   %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
   %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
 
-  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
-  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f16>, #[[$NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f16>, #[[$NEW_BLOCKED]]>
   %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
   tt.return
 }
@@ -101,7 +124,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
 // Clip to vector size 4 (128bit) which is the largest supported load width
-// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK: #[[$NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK-LABEL: async_copy_vector_size_8
 tt.func public @async_copy_vector_size_8(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                %arg1: i32 {tt.divisibility = 16 : i32},
                %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
@@ -112,8 +136,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
   %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
   %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
 
-  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
-  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f16>, #[[$NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f16>, #[[$NEW_BLOCKED]]>
   %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
   tt.return
 }
@@ -126,7 +150,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
 // The order of #blocked and #shared are different so we need to clip to 1 element
-// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK: #[[$NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK-LABEL: async_copy_different_order
 tt.func public @async_copy_different_order(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                %arg1: i32 {tt.divisibility = 16 : i32},
                %arg2: !ttg.memdesc<32x64xf32, #shared, #smem, mutable>) {
@@ -137,8 +162,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
   %4 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x64x!tt.ptr<f32>, #blocked>
   %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f32>, #blocked>, tensor<32x64xi32, #blocked>
 
-  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f32>, #[[NEW_BLOCKED]]>
-  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f32>, #[[NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
+  // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f32>, #[[$NEW_BLOCKED]]>
   %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f32>, #blocked> -> <32x64xf32, #shared, #smem, mutable>
   tt.return
 }
@@ -153,7 +178,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // The shared layout should not be changed
 // CHECK: #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 // CHECK-NOT: #shared1
-
+// CHECK-LABEL: async_copy_2d_swizzled
 tt.func @async_copy_2d_swizzled(%input: tensor<64x64x!tt.ptr<f16>, #blocked>,
                %view: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>) {
   // CHECK: %{{.*}} = ttg.async_copy_global_to_local {{.*}} -> <64x64xf16, #shared, #smem, mutable>

test/TritonGPU/amd/amd-update-async-wait-count.mlir

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [1, 0]}>
-#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
+#shared1 = #ttg.padded_shared<[4:+4] {order = [1, 0], shape=[16, 256]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
 // CHECK-LABEL: simple_waitcnt
