Commit 14632a2

Merge commit '76f576c8112e8d125c9c66973009c77daa0671fa'
2 parents: 8f30fe0 + 76f576c; commit 14632a2

16 files changed: 1167 additions, 662 deletions


bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -97,6 +97,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
 mlir::registerTritonAMDGPUCanonicalizePointers();
 mlir::registerTritonAMDGPUConvertToBufferOps();
 mlir::registerTritonAMDGPUInThreadTranspose();
+mlir::registerTritonAMDGPUCoalesceAsyncCopy();
 mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
 mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
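
This hunk makes the new AMD pass visible to the triton-opt driver, next to the other TritonAMDGPU pass registrations. A minimal invocation sketch, mirroring the RUN line of the new lit test added further down (input.mlir is a placeholder file name; gfx950 is simply the target that test uses):

  triton-opt input.mlir -split-input-file --tritonamdgpu-coalesce-async-copy=arch-generation-name=gfx950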

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 38 additions & 0 deletions
@@ -229,6 +229,44 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

 // -----

+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+#tmem_scales = #ttng.tensor_memory_scales_encoding<>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+// CHECK-LABEL: @tc_gen5_mma_block_scale_fp4_a
+// CHECK: %[[DESC0:.+]] = llvm.mlir.constant(144769664 : i32) : i32
+// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC0]]
+// CHECK: %[[DESC1:.+]] = llvm.mlir.constant(681640592 : i32) : i32
+// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC1]]
+// CHECK: %[[DESC2:.+]] = llvm.mlir.constant(1218511520 : i32) : i32
+// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC2]]
+// CHECK: %[[DESC3:.+]] = llvm.mlir.constant(1755382448 : i32) : i32
+// CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC3]]
+tt.func @tc_gen5_mma_block_scale_fp4_a(%a: !ttg.memdesc<128x64xi8, #shared1, #ttg.shared_memory>,
+%b: !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+%c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
+%scale_a: !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+%scale_b: !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+%useAcc: i1,
+%pred: i1,
+%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) {
+ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e2m1 rhs = e4m3, %barrier :
+(!ttg.memdesc<128x64xi8, #shared1, #ttg.shared_memory>,
+!ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
+!ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+!ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+i1,
+i1,
+!ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) -> ()
+tt.return
+}
+}
+
+// -----
+
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
 #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [2], CTASplitNum = [1], CTAOrder = [0]}>
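
A note on the operand shapes in this new test: with lhs = e2m1 (fp4), the 128x64xi8 A operand packs two 4-bit values per byte and thus, as far as the types indicate, represents a 128x128 fp4 tile, matching the K = 128 of the 128x128 e4m3 B operand and the 128x128 f32 accumulator held in tensor memory.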
Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+// RUN: triton-opt %s -split-input-file --tritonamdgpu-coalesce-async-copy=arch-generation-name=gfx950 | FileCheck %s
+
+#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+// sizePerThread = [1] because we have no information about contiguity of src pointers
+// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+tt.func @async_copy_1d(%input: tensor<1024x!tt.ptr<f16>, #blocked>,
+%view: !ttg.memdesc<1024xf16, #shared, #smem, mutable>) {
+// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<1024x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<1024x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+%token = ttg.async_copy_global_to_local %input, %view: tensor<1024x!tt.ptr<f16>, #blocked> -> <1024xf16, #shared, #smem, mutable>
+tt.return
+}
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+// sizePerThread = [1, 1] because we have no information about contiguity of src pointers
+// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+tt.func @async_copy_2d(%input: tensor<64x64x!tt.ptr<f16>, #blocked>,
+%view: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>) {
+// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+%token = ttg.async_copy_global_to_local %input, %view: tensor<64x64x!tt.ptr<f16>, #blocked> -> <64x64xf16, #shared, #smem, mutable>
+tt.return
+}
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [8, 1, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [1,2,2], order = [0,1,2]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0,1,2]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+// sizePerThread = [1, 1, 1] because we have no information about contiguity of src pointers
+// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}>
+tt.func @async_copy_3d(%input: tensor<1024x1024x1024x!tt.ptr<f16>, #blocked>,
+%view: !ttg.memdesc<1024x1024x1024xf16, #shared, #smem, mutable>) {
+// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<1024x1024x1024x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<1024x1024x1024x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+%token = ttg.async_copy_global_to_local %input, %view: tensor<1024x1024x1024x!tt.ptr<f16>, #blocked> -> <1024x1024x1024xf16, #shared, #smem, mutable>
+tt.return
+}
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+tt.func @async_copy_with_mask_and_other(%input: tensor<64x64x!tt.ptr<f16>, #blocked>,
+%view: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>,
+%mask: tensor<64x64xi1, #blocked>,
+%other: tensor<64x64xf16, #blocked>) {
+// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64xi1, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x64xf16, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+%token = ttg.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x64x!tt.ptr<f16>, #blocked> -> <64x64xf16, #shared, #smem, mutable>
+tt.return
+}
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+// Clip to vector size 2 (32bit) because we do not support 64 bit loads to lds
+// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+tt.func public @async_copy_vector_size_2(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+%arg1: i32 {tt.divisibility = 16 : i32},
+%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
+// We need the index calculation so AxisAnalysis sees that we can vectorize the load
+%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
+%4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
+%5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
+
+// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+%6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
+tt.return
+}
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+// Clip to vector size 4 (128bit) which is the largest supported load width
+// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+tt.func public @async_copy_vector_size_8(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+%arg1: i32 {tt.divisibility = 16 : i32},
+%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
+// We need the index calculation so AxisAnalysis sees that we can vectorize the load based on the src contiguity
+%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
+%4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
+%5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
+
+// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+%6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
+tt.return
+}
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+// The order of #blocked and #shared are different so we need to clip to 1 element
+// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+tt.func public @async_copy_different_order(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+%arg1: i32 {tt.divisibility = 16 : i32},
+%arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
+// We need the index calculation so AxisAnalysis sees that we can vectorize the load based on the src contiguity
+%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+%3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
+%4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
+%5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
+
+// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<32x64x!tt.ptr<f16>, #[[NEW_BLOCKED]]>
+%6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
+tt.return
+}
+}
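
A quick cross-check of the vector-width clipping these tests exercise (the element widths come from the test IR; the arithmetic below is a sanity check, not taken from the patch):

- f16 is 16 bits, so sizePerThread = [1, 4] gives 4 x 16 = 64 contiguous bits per lane; 64-bit direct-to-LDS loads are not supported, hence the clip to 2 elements (32 bits) in async_copy_vector_size_2.
- sizePerThread = [1, 16] gives 16 x 16 = 256 bits per lane; 128 bits is the largest supported load width, hence the clip to 8 elements (sizePerThread = [1, 8]) in async_copy_vector_size_8.
- In async_copy_different_order the #blocked and #shared orders differ, so per-lane contiguity cannot be preserved into LDS and the copy is clipped to 1 element.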

test/TritonGPU/amd/optimize-lds-usage.mlir

Lines changed: 22 additions & 20 deletions
@@ -44,17 +44,19 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 }
 }

-// FIXME: This was broken in https://github.com/triton-lang/triton/pull/5840
-// // -----
+// -----

+// CHECK-DAG: [[$BLOCKED1:#.*]] = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
+// CHECK-DAG: [[$BLOCKED2:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 8, 8], warpsPerCTA = [1, 2, 4], order = [1, 2, 0]}>
+// CHECK-DAG: [[$MMA:#.*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
 // Check that optimization works with 3d tensors
 // in case of relatively small scratch buffer
-// DISABLE-CHECK-LABEL: alloc_convert_3d_load
-// DISABLE-CHECK-32KLIMIT-LABEL: alloc_convert_3d_load
-// DISABLE-CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
-// DISABLE-CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma
-// DISABLE-CHECK: %2 = ttg.convert_layout %1 : {{.*}}#mma{{.*}}#mma1
-// DISABLE-CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
+// CHECK-LABEL: alloc_convert_3d_load
+// CHECK-32KLIMIT-LABEL: alloc_convert_3d_load
+// CHECK: [[V0:%.*]] = ttg.local_alloc {{.*}}[[$BLOCKED1]]{{.*}}
+// CHECK: [[V1:%.*]] = ttg.convert_layout {{.*}}[[$BLOCKED1]]{{.*}}[[$BLOCKED2]]
+// CHECK: [[V2:%.*]] = ttg.convert_layout [[V1]] : {{.*}}[[$BLOCKED2]]{{.*}}[[$MMA]]
+// CHECK: [[V3:%.*]] = ttg.local_load [[V0]] : {{.*}}#ttg.dot_op<{opIdx = 0, parent = [[$MMA]], kWidth = 4}>>
 #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
 #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
 #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2]}>
@@ -93,22 +95,21 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 }
 }

-// FIXME: This was broken in https://github.com/triton-lang/triton/pull/5840
 // -----

 // Check that optimization correctly handles LDS shortcut (see #mma2 -> #dotop2 conversion)
-// DISABLE-CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
-// DISABLE-CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}>
-// DISABLE-CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
-// DISABLE-CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
-// DISABLE-CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
+// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+// CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}>
+// CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+// CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
+// CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>

-// DISABLE-CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}})
-// DISABLE-CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem>
-// DISABLE-CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
-// DISABLE-CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
-// DISABLE-CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
-// DISABLE-CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
+// CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}})
+// CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem>
+// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
+// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
+// CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
+// CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
 #blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma1 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
 #mma2 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
@@ -125,6 +126,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 tt.return
 }
 }
+
 // -----

 // Checks that optimization do not crash on 1d tensor
