
Commit 1cb0d99

[AMD] Use Linear Layout conversions for AMDWmma (#5255)
Enable linear layout (LL) conversions for WMMA as well as for MFMA layouts. See also: triton-lang/triton#5210

Signed-off-by: Ilya Veselov <[email protected]>
1 parent 55b741d · commit 1cb0d99

2 files changed: 67 additions & 2 deletions


lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
@@ -374,9 +374,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
 // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
 // completed before we can remove the layoutIsOK check:
-// 1. Support for AMD's WMMA
+// 1. Support for AMD's WMMA dot operand
 std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
-  if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
+  if (isa<MmaEncodingTrait>(layout)) {
     return !useLegacyMMAConversion;
   }
   if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
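
For reference, below is a minimal sketch of what the new check amounts to, assuming (as in upstream Triton) that NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr, and the AMD WMMA encoding all implement the MmaEncodingTrait attribute interface declared by the TritonGPU dialect. The helper name takesLinearLayoutPath and the surrounding scaffolding are illustrative only, not the file's actual structure; the header path and namespaces are assumed from the upstream repository.

// Illustrative sketch, not the actual code in ConvertLayoutOpToLLVM.cpp.
// It shows the effect of replacing the concrete-attribute list with a check
// on the shared MmaEncodingTrait interface: WMMA encodings now take the
// same linear-layout conversion path as NVIDIA MMA and AMD MFMA encodings.
#include "mlir/IR/Attributes.h"                  // mlir::Attribute, mlir::isa
#include "triton/Dialect/TritonGPU/IR/Dialect.h" // assumed header for the TritonGPU encodings

// Hypothetical helper; in the real pass this logic lives inside the
// layoutIsOK lambda shown in the diff above.
static bool takesLinearLayoutPath(mlir::Attribute layout,
                                  bool useLegacyMMAConversion) {
  // Before: mlir::isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout).
  // After: any encoding implementing the MMA interface, which includes the
  // AMD WMMA encoding, unless the legacy MMA conversion is forced.
  if (mlir::isa<mlir::triton::gpu::MmaEncodingTrait>(layout))
    return !useLegacyMMAConversion;
  return false;
}

The updated TODO keeps the layoutIsOK guard in place because support for AMD's WMMA dot operand is still pending.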

test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir

Lines changed: 65 additions & 0 deletions
@@ -1,5 +1,6 @@
 // RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s

+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
 #mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
 #mma2 = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}>
@@ -97,6 +98,70 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
 tt.return
 }
+
+// CHECK-LABEL: blocked_to_wmma1
+tt.func @blocked_to_wmma1(%arg0: tensor<128x16xi32, #blocked>) {
+// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+// CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma1>
+tt.return
+}
+
+// CHECK-LABEL: slice_blocked_to_wmma1
+tt.func @slice_blocked_to_wmma1(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) {
+// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+// CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
+%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>>
+tt.return
+}
+
+// CHECK-LABEL: wmma1_to_blocked
+tt.func @wmma1_to_blocked(%arg0: tensor<128x16xi32, #mma1>) {
+// CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma1> -> tensor<128x16xi32, #blocked>
+tt.return
+}
+
+// CHECK-LABEL: slice_wmma1_to_blocked
+tt.func @slice_wmma1_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>>) {
+// CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
+// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+tt.return
+}
+
+// CHECK-LABEL: blocked_to_wmma2
+tt.func @blocked_to_wmma2(%arg0: tensor<128x16xi32, #blocked>) {
+// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+// CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma2>
+tt.return
+}
+
+// CHECK-LABEL: slice_blocked_to_wmma2
+tt.func @slice_blocked_to_wmma2(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) {
+// CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+// CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
+%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>>
+tt.return
+}
+
+// CHECK-LABEL: wmma2_to_blocked
+tt.func @wmma2_to_blocked(%arg0: tensor<128x16xi32, #mma2>) {
+// CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma2> -> tensor<128x16xi32, #blocked>
+tt.return
+}
+
+// CHECK-LABEL: slice_wmma2_to_blocked
+tt.func @slice_wmma2_to_blocked(%arg0: tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>>) {
+// CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
+// CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma2}>> -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+tt.return
+}
 }

 // -----
