Skip to content

Commit af0649d

Browse files
authored
[AMD] Use warp shuffle for fp8 MFMA to dot operand layout conversion (#5139)
Adding a shortcut case for fp8 MFMA to dot operand layout conversion that avoids using shared memory, to speed up FP8 attention kernels.
1 parent 4ae95e7 commit af0649d

File tree

5 files changed

+378
-2
lines changed

5 files changed

+378
-2
lines changed

include/triton/Analysis/Utility.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,11 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
218218
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
219219
RankedTensorType dstTy);
220220

221+
// Check if MFMA layout can be converted to the dot operand
222+
// layout using warp shuffle.
223+
bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
224+
RankedTensorType dstTy);
225+
221226
// TODO: Move utility functions that belong to ConvertLayoutOp to class
222227
// ConvertLayoutOpHelper in the future
223228
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);

lib/Analysis/Utility.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/IR/Dialect.h"
1111
#include "mlir/IR/Matchers.h"
1212
#include "mlir/Support/LLVM.h"
13+
#include "triton/Conversion/MLIRTypes.h"
1314
#include "triton/Dialect/Triton/IR/Dialect.h"
1415
#include "triton/Dialect/Triton/IR/Utility.h"
1516
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -632,6 +633,25 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
632633
return ans;
633634
}
634635

636+
bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
637+
RankedTensorType dstTy) {
638+
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
639+
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
640+
if (!mfmaLayout || !dotOperandLayout)
641+
return false;
642+
643+
// Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case
644+
return dotOperandLayout.getParent() == mfmaLayout &&
645+
dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
646+
dotOperandLayout.getKWidth() == 8 &&
647+
getContigPerThread(mfmaLayout)[1] == 4 &&
648+
((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
649+
(mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
650+
triton::type::isFloat8(srcTy.getElementType()) &&
651+
triton::type::isFloat8(dstTy.getElementType()) &&
652+
mfmaLayout.getWarpsPerCTA()[1] == 1;
653+
}
654+
635655
// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
636656
// under kBlock, kWarp or kLane (in that order). The idea here is that if we
637657
// have a transformation that's the identity on kBlock, we don't need to use
@@ -730,7 +750,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
730750
// supported yet in Triton's backend.
731751
return !cvtReordersRegisters(srcTy, dstTy) &&
732752
!isBlockedToDotShortcut(srcTy, dstTy) &&
733-
!matchMmaV3AndDotOperandLayout(srcTy, dstTy);
753+
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
754+
// to be removed when generalized warp shuffle conversions
755+
// are ready:
756+
!matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
734757
}
735758

736759
bool atomicNeedsSharedMemory(Value value) {

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
409409
return failure();
410410
}
411411

412+
// The following check can be removed when generalized warp shuffle
413+
// conversions are ready:
414+
if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) {
415+
return failure();
416+
}
417+
412418
assert(cvtNeedsSharedMemory(srcTy, dstTy));
413419

414420
SmallVector<Value> inVals =

test/Conversion/amd/mfma-shortcut.mlir

Lines changed: 189 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s
1+
// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s
22

33
#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
44
#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>
@@ -27,3 +27,191 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
2727
tt.return
2828
}
2929
}
30+
31+
// -----
32+
33+
#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
34+
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
35+
36+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
37+
// CHECK-LABEL: mfma_dot_cvt_f8_mfma32
38+
tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
39+
// CHECK-NOT: store
40+
// CHECK-NOT: load
41+
42+
// CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
43+
// CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]
44+
45+
// CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
46+
// CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
47+
48+
// CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
49+
// CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
50+
// CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
51+
52+
// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
53+
// CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
54+
55+
// CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
56+
// CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
57+
58+
// CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
59+
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
60+
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
61+
// CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
62+
// CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]]
63+
64+
// CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
65+
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
66+
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
67+
// CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
68+
// CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]]
69+
70+
// Input (8 values): (vec0, vec1)
71+
// Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
72+
// resVec0 resVec1
73+
// lanes 0-31: (vec0 , vec0 >> 32) (mask0=1)
74+
// lanes 32-63: (vec1 >> 32, vec1 ) (mask0=0)
75+
76+
// CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]]
77+
// CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]]
78+
79+
// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
80+
// CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
81+
// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
82+
// CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
83+
84+
// CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
85+
// CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]
86+
87+
// CHECK: llvm.return
88+
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
89+
tt.return
90+
}
91+
}
92+
93+
// -----
94+
95+
#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
96+
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
97+
98+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
99+
// CHECK-LABEL: mfma_dot_cvt_bf8_mfma32
100+
tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) {
101+
// CHECK-NOT: store
102+
// CHECK-NOT: load
103+
// CHECK: rocdl.ds_bpermute
104+
// CHECK: llvm.return
105+
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
106+
tt.return
107+
}
108+
}
109+
110+
// -----
111+
112+
#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
113+
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
114+
115+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
116+
// CHECK-LABEL: mfma_dot_cvt_f8_mfma16
117+
tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
118+
// CHECK-NOT: store
119+
// CHECK-NOT: load
120+
121+
// CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
122+
// CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]
123+
124+
// CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32)
125+
// CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
126+
// CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32)
127+
// CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
128+
129+
// CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
130+
// CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
131+
// CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
132+
133+
// CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]]
134+
// CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]]
135+
136+
// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]]
137+
// CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
138+
139+
// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
140+
// CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
141+
142+
// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]]
143+
// CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
144+
145+
// CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
146+
// CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
147+
148+
// CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
149+
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
150+
// CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]]
151+
// CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
152+
// CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]]
153+
154+
// CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
155+
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
156+
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
157+
// CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
158+
// CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]]
159+
160+
// CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
161+
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
162+
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
163+
// CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
164+
// CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]]
165+
166+
// CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
167+
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
168+
// CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]]
169+
// CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
170+
// CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]]
171+
172+
// Input (8 values): (vec0, vec1)
173+
// Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
174+
// resVec0 resVec1
175+
// lanes 0-15: (vec0 , vec0 >> 16) (mask0=1, mask1=1)
176+
// lanes 16-31: (vec0 >> 16, vec0 >> 32) (mask0=1, mask1=0)
177+
// lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1)
178+
// lanes 48-63: (vec1 >> 48, vec1 ) (mask0=0, mask1=0)
179+
180+
// CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8>
181+
// CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8>
182+
// CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
183+
184+
// CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8>
185+
// CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8>
186+
// CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
187+
188+
// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
189+
// CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
190+
// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
191+
// CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
192+
193+
// CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
194+
// CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]
195+
196+
// CHECK: llvm.return
197+
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
198+
tt.return
199+
}
200+
}
201+
202+
// -----
203+
204+
#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
205+
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
206+
207+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
208+
// CHECK-LABEL: mfma_dot_cvt_bf8_mfma16
209+
tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) {
210+
// CHECK-NOT: store
211+
// CHECK-NOT: load
212+
// CHECK: rocdl.ds_bpermute
213+
// CHECK: llvm.return
214+
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
215+
tt.return
216+
}
217+
}

0 commit comments

Comments
 (0)