
Commit c6da81a

Revert "[AMD] Use warp shuffle for MFMA to Dot operand layout conversion (FP8)" (#5240)
It is causing a performance regression; revert until it can be investigated. Reverts triton-lang/triton#5139
1 parent 84ced0e commit c6da81a

5 files changed: +2 additions, -378 deletions

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 5 deletions
@@ -218,11 +218,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy);
 
-// Check if MFMA layout can be converted to the dot operand
-// layout using warp shuffle.
-bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
-                                       RankedTensorType dstTy);
-
 // TODO: Move utility functions that belong to ConvertLayoutOp to class
 // ConvertLayoutOpHelper in the future
 bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);

lib/Analysis/Utility.cpp

Lines changed: 1 addition & 24 deletions
@@ -10,7 +10,6 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
-#include "triton/Conversion/MLIRTypes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -633,25 +632,6 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
   return ans;
 }
 
-bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
-                                       RankedTensorType dstTy) {
-  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
-  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
-  if (!mfmaLayout || !dotOperandLayout)
-    return false;
-
-  // Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case
-  return dotOperandLayout.getParent() == mfmaLayout &&
-         dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
-         dotOperandLayout.getKWidth() == 8 &&
-         getContigPerThread(mfmaLayout)[1] == 4 &&
-         ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
-          (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
-         triton::type::isFloat8(srcTy.getElementType()) &&
-         triton::type::isFloat8(dstTy.getElementType()) &&
-         mfmaLayout.getWarpsPerCTA()[1] == 1;
-}
-
 // We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
 // under kBlock, kWarp or kLane (in that order). The idea here is that if we
 // have a transformation that's the identity on kBlock, we don't need to use
@@ -750,10 +730,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
   // supported yet in Triton's backend.
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !isBlockedToDotShortcut(srcTy, dstTy) &&
-         !matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
-         // to be removed when generalized warp shuffle conversions
-         // are ready:
-         !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
+         !matchMmaV3AndDotOperandLayout(srcTy, dstTy);
 }
 
 bool atomicNeedsSharedMemory(Value value) {
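
For readers skimming this revert: the deleted shortcut only ever fired for a narrow FP8 case, and with it gone those conversions again report that they need shared memory. The snippet below restates the reverted predicate over plain data so the conditions are easier to scan; MfmaLayout and DotOperandLayout here are hypothetical stand-ins, not Triton's MLIR attribute classes.

    #include <array>

    // Hypothetical plain-data mirrors of the MLIR attributes; the field names
    // follow the accessors used in the removed matchMFMAAndDotOperandShuffleCase.
    struct MfmaLayout {
      int mDim = 0, nDim = 0;             // instrShape
      bool isTransposed = false;
      std::array<int, 2> warpsPerCTA{};
      std::array<int, 2> contigPerThread{};
    };

    struct DotOperandLayout {
      const MfmaLayout *parent = nullptr;
      int opIdx = 0;
      int kWidth = 0;
    };

    // Sketch of the reverted check: both element types are FP8, the destination
    // is operand A (opIdx == 0) of the same transposed MFMA layout, kWidth is 8,
    // and the MFMA tile is 16x16 or 32x32 with warpsPerCTA[1] == 1.
    bool matchesFp8MfmaToDotShuffleCase(const MfmaLayout &mfma,
                                        const DotOperandLayout &dot,
                                        bool srcIsFp8, bool dstIsFp8) {
      return dot.parent == &mfma && dot.opIdx == 0 && mfma.isTransposed &&
             dot.kWidth == 8 && mfma.contigPerThread[1] == 4 &&
             ((mfma.mDim == 16 && mfma.nDim == 16) ||
              (mfma.mDim == 32 && mfma.nDim == 32)) &&
             srcIsFp8 && dstIsFp8 && mfma.warpsPerCTA[1] == 1;
    }
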

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 6 deletions
@@ -409,12 +409,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return failure();
     }
 
-    // The following check can be removed when generalized warp shuffle
-    // conversions are ready:
-    if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) {
-      return failure();
-    }
-
     assert(cvtNeedsSharedMemory(srcTy, dstTy));
 
     SmallVector<Value> inVals =
Lines changed: 1 addition & 189 deletions
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s
+// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s
 
 #mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
 #dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>
@@ -27,191 +27,3 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
-
-// -----
-
-#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
-#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
-
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-  // CHECK-LABEL: mfma_dot_cvt_f8_mfma32
-  tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
-    // CHECK-NOT: store
-    // CHECK-NOT: load
-
-    // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
-    // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]
-
-    // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
-    // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
-
-    // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
-    // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
-    // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
-
-    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
-    // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
-
-    // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
-    // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
-
-    // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
-    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
-    // CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
-    // CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]]
-
-    // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
-    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
-    // CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
-    // CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]]
-
-    // Input (8 values): (vec0, vec1)
-    // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
-    //              resVec0      resVec1
-    // lanes 0-31:  (vec0      , vec0 >> 32) (mask0=1)
-    // lanes 32-63: (vec1 >> 32, vec1      ) (mask0=0)
-
-    // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]]
-    // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]]
-
-    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
-    // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
-    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
-    // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
-
-    // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
-    // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]
-
-    // CHECK: llvm.return
-    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
-    tt.return
-  }
-}
-
-// -----
-
-#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
-#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
-
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-  // CHECK-LABEL: mfma_dot_cvt_bf8_mfma32
-  tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) {
-    // CHECK-NOT: store
-    // CHECK-NOT: load
-    // CHECK: rocdl.ds_bpermute
-    // CHECK: llvm.return
-    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
-    tt.return
-  }
-}
-
-// -----
-
-#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
-#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
-
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-  // CHECK-LABEL: mfma_dot_cvt_f8_mfma16
-  tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
-    // CHECK-NOT: store
-    // CHECK-NOT: load
-
-    // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
-    // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]
-
-    // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32)
-    // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
-    // CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32)
-    // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
-
-    // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
-    // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
-    // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
-
-    // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]]
-    // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]]
-
-    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]]
-    // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
-
-    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
-    // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
-
-    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]]
-    // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
-
-    // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
-    // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
-
-    // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
-    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]]
-    // CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
-    // CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]]
-
-    // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
-    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
-    // CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
-    // CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]]
-
-    // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
-    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
-    // CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
-    // CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]]
-
-    // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
-    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]]
-    // CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
-    // CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]]
-
-    // Input (8 values): (vec0, vec1)
-    // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
-    //              resVec0      resVec1
-    // lanes 0-15:  (vec0      , vec0 >> 16) (mask0=1, mask1=1)
-    // lanes 16-31: (vec0 >> 16, vec0 >> 32) (mask0=1, mask1=0)
-    // lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1)
-    // lanes 48-63: (vec1 >> 48, vec1      ) (mask0=0, mask1=0)
-
-    // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8>
-    // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8>
-    // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
-
-    // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8>
-    // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8>
-    // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
-
-    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
-    // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
-    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
-    // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
-
-    // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
-    // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]
-
-    // CHECK: llvm.return
-    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
-    tt.return
-  }
-}
-
-// -----
-
-#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
-#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
-
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-  // CHECK-LABEL: mfma_dot_cvt_bf8_mfma16
-  tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) {
-    // CHECK-NOT: store
-    // CHECK-NOT: load
-    // CHECK: rocdl.ds_bpermute
-    // CHECK: llvm.return
-    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
-    tt.return
-  }
-}
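
The deleted comment tables above fully determine the data movement of the reverted shortcut, and it can be useful to see them as ordinary arithmetic. The following standalone C++ sketch is illustrative only (the function names are made up; this is not Triton code): it reproduces the mapping in which each lane either keeps its own vec0/vec1 or fetches a copy via ds_bpermute from lane (lane + n) % 64, the bpermute byte address being that source lane shifted left by 2.

    #include <cstdio>

    struct Src { int vec; int lane; };   // which input vector is read, and from which lane

    // 32x32 case: a single +32 shuffle, selected by mask0 = lane < 32.
    void mfma32Mapping(int lane, Src &res0, Src &res1) {
      bool mask0 = lane < 32;
      int from32 = (lane + 32) % 64;     // ds_bpermute byte address = from32 << 2
      res0 = mask0 ? Src{0, lane} : Src{1, from32};
      res1 = mask0 ? Src{0, from32} : Src{1, lane};
    }

    // 16x16 case: +16, +32 and +48 shuffles, selected by mask0 = lane < 32
    // and mask1 = (lane % 32) < 16.
    void mfma16Mapping(int lane, Src &res0, Src &res1) {
      bool mask0 = lane < 32;
      bool mask1 = (lane % 32) < 16;
      int from16 = (lane + 16) % 64;
      int from32 = (lane + 32) % 64;
      int from48 = (lane + 48) % 64;
      res0 = mask0 ? (mask1 ? Src{0, lane} : Src{0, from16})
                   : (mask1 ? Src{1, from32} : Src{1, from48});
      res1 = mask0 ? (mask1 ? Src{0, from16} : Src{0, from32})
                   : (mask1 ? Src{1, from48} : Src{1, lane});
    }

    int main() {
      // One representative lane per quarter-warp; matches the tables in the
      // removed test comments (e.g. lanes 0-15 of mfma16: (vec0, vec0 >> 16)).
      const int lanes[] = {0, 16, 32, 48};
      for (int lane : lanes) {
        Src a0, a1, b0, b1;
        mfma32Mapping(lane, a0, a1);
        mfma16Mapping(lane, b0, b1);
        std::printf("lane %2d  mfma32: (vec%d@%2d, vec%d@%2d)  mfma16: (vec%d@%2d, vec%d@%2d)\n",
                    lane, a0.vec, a0.lane, a1.vec, a1.lane,
                    b0.vec, b0.lane, b1.vec, b1.lane);
      }
      return 0;
    }
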
