
Commit 5a87bde

[AMD] Remove manual transposed MFMA to DotOp layout conversion (#7625)
Generalized warp-shuffle conversions were introduced in #7558. This PR removes the two manual implementations of the transposed MFMA to dot operand layout conversion and their corresponding LIT tests.

For the 32x32 instruction, in the LIT test `mfma_dot_cvt_f8_mfma32`:
- The manual implementation uses 4 `ds_bpermute`s and 4 `select`s.
- The general pathway uses 2 `ds_bpermute`s and 6 `select`s.

For the 16x16 instruction, in the LIT test `mfma_dot_cvt_f8_mfma16`:
- The manual implementation uses 8 `ds_bpermute`s and 12 `select`s.
- The general pathway uses 4 `ds_bpermute`s and 10 `select`s.

I have not run benchmarks to verify that this causes no performance regression, but given the reduced `ds_bpermute` counts, noticeable improvements seem likely; maintainers should still check (cf. #7574).

There is room for further optimization of the 32x32 case on hardware with CDNA4 support, using the `V_PERMLANE32_SWAP_B32` instruction as in `ConvertLayoutOpMFMAToLinearConversion`: two such instructions can replace the 2 `ds_bpermute`s and 6 `select`s above. This PR does not implement that, since it makes more sense to do it generally, via the `getWarpLayoutConvertDecomposition` utility function, than by hand for each case; this can be done in a later PR.
1 parent 2fad92c commit 5a87bde
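The lane-exchange pattern the removed 32x32 path implemented (and which the general pathway must reproduce) is documented in the deleted LIT test further down: each lane holds two packed 4-byte vectors, lanes 0-31 keep `vec0` and pull the second slot from lane + 32, while lanes 32-63 pull the first slot from (lane + 32) % 64 and keep `vec1`. The following minimal Python sketch simulates that exchange with plain lists to make the data movement concrete; `ds_bpermute` here is a simplified pull-style model of the hardware instruction, not an actual API.

```python
# Minimal model of the 32x32 transposed-MFMA -> dot-operand lane exchange
# described in the removed mfma_dot_cvt_f8_mfma32 test (illustrative only).
WARP_SIZE = 64

def ds_bpermute(values, src_lane):
    # Simplified pull-style model: every lane reads values[src_lane(lane)].
    return [values[src_lane(lane)] for lane in range(WARP_SIZE)]

# Tag each register value by (name, owning lane) so movement is visible.
vec0 = [("vec0", lane) for lane in range(WARP_SIZE)]
vec1 = [("vec1", lane) for lane in range(WARP_SIZE)]

# Two cross-lane reads from (lane + 32) % 64: the two ds_bpermutes in the test.
shfl_vec0 = ds_bpermute(vec0, lambda lane: (lane + 32) % WARP_SIZE)
shfl_vec1 = ds_bpermute(vec1, lambda lane: (lane + 32) % WARP_SIZE)

# Selects on mask0 = (lane < 32): lanes 0-31 keep vec0 in the first slot and
# take vec0 from lane + 32 in the second; lanes 32-63 take vec1 from
# (lane + 32) % 64 in the first slot and keep vec1 in the second.
res_vec0 = [vec0[l] if l < 32 else shfl_vec1[l] for l in range(WARP_SIZE)]
res_vec1 = [shfl_vec0[l] if l < 32 else vec1[l] for l in range(WARP_SIZE)]

# Matches the comment in the removed test:
#   lanes 0-31:  (vec0      , vec0 >> 32)
#   lanes 32-63: (vec1 >> 32, vec1      )
assert res_vec0[0] == ("vec0", 0) and res_vec1[0] == ("vec0", 32)
assert res_vec0[40] == ("vec1", 8) and res_vec1[40] == ("vec1", 40)
```

The same bookkeeping, extended to the four shuffle addresses (+16, +32, +32, +48) and the two masks, gives the 16x16 pattern that the deleted `mfma_dot_cvt_f8_mfma16` test checked.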

File tree: 4 files changed (+1, -342 lines)

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 5 deletions
@@ -252,11 +252,6 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy);
 // warps, and possibly blocks.
 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
 
-// Check if MFMA layout can be converted to the dot operand
-// layout using warp shuffle.
-bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
-                                       RankedTensorType dstTy);
-
 // TODO: Move utility functions that belong to ConvertLayoutOp to class
 // ConvertLayoutOpHelper in the future
 bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);

lib/Analysis/Utility.cpp

Lines changed: 1 addition & 24 deletions
@@ -713,24 +713,6 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }
 
-bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
-                                       RankedTensorType dstTy) {
-  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
-  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
-  if (!mfmaLayout || !dotOperandLayout)
-    return false;
-
-  // Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case
-  return dotOperandLayout.getParent() == mfmaLayout &&
-         dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
-         dotOperandLayout.getKWidth() == 8 &&
-         ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
-          (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
-         triton::type::isFloat8(srcTy.getElementType()) &&
-         triton::type::isFloat8(dstTy.getElementType()) &&
-         mfmaLayout.getWarpsPerCTA()[1] == 1;
-}
-
 // We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
 // under the common dimensions. The idea here is that if we have a
 // transformation that's the identity on kBlock, we don't need to use
@@ -788,13 +770,8 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
 }
 
 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
-  // TODO(jlebar): Remove these special cases `isMfmaToDotShortcut` once
-  // they're fully subsumed by the linear-layout checks.
   return !cvtReordersRegisters(srcTy, dstTy) &&
-         !cvtNeedsWarpShuffle(srcTy, dstTy) &&
-         // to be removed when generalized warp shuffle conversions
-         // are ready:
-         !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
+         !cvtNeedsWarpShuffle(srcTy, dstTy);
 }
 
 namespace {

test/Conversion/amd/mfma-shortcut.mlir

Lines changed: 0 additions & 158 deletions
@@ -19,70 +19,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 #mfma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
 #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
 
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  // GFX942-LABEL: mfma_dot_cvt_f8_mfma32
-  tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
-    // GFX942-NOT: store
-    // GFX942-NOT: load
-
-    // GFX942: [[val3:%.*]] = llvm.extractvalue %arg0[3]
-    // GFX942: [[val7:%.*]] = llvm.extractvalue %arg0[7]
-
-    // GFX942-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
-    // GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
-
-    // GFX942: [[threadId:%.*]] = rocdl.workitem.id.x
-    // GFX942: [[c255:%.*]] = llvm.mlir.constant(255 : i32)
-    // GFX942: [[RTID:%.*]] = llvm.and [[threadId]], [[c255]]
-    // GFX942: [[laneId:%.*]] = llvm.urem [[RTID]], [[c64]]
-    // GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
-
-    // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
-    // GFX942: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
-
-    // GFX942: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
-    // GFX942: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
-
-    // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
-    // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
-    // GFX942: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
-    // GFX942: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]]
-
-    // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
-    // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
-    // GFX942: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
-    // GFX942: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]]
-
-    // Input (8 values): (vec0, vec1)
-    // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
-    //                  resVec0      resVec1
-    //   lanes 0-31:  (vec0      , vec0 >> 32)   (mask0=1)
-    //   lanes 32-63: (vec1 >> 32, vec1      )   (mask0=0)
-
-    // GFX942: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]]
-    // GFX942: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]]
-
-    // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
-    // GFX942: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
-    // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
-    // GFX942: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
-
-    // GFX942: llvm.insertvalue [[resVal3]], {{.*}}[3]
-    // GFX942: llvm.insertvalue [[resVal7]], {{.*}}[7]
-
-    // GFX942: llvm.return
-    %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
-    tt.return
-  }
-}
-
-// -----
-
-#mfma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
-#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
-
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
   // GFX942-LABEL: mfma_dot_cvt_bf8_mfma32
   tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) {
@@ -100,100 +36,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 #mfma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
 #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
 
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  // GFX942-LABEL: mfma_dot_cvt_f8_mfma16
-  tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
-    // GFX942-NOT: store
-    // GFX942-NOT: load
-
-    // GFX942: [[val3:%.*]] = llvm.extractvalue %arg0[3]
-    // GFX942: [[val7:%.*]] = llvm.extractvalue %arg0[7]
-
-    // GFX942-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32)
-    // GFX942-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
-    // GFX942-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32)
-    // GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
-
-    // GFX942: [[threadId:%.*]] = rocdl.workitem.id.x
-    // GFX942: [[c255:%.*]] = llvm.mlir.constant(255 : i32)
-    // GFX942: [[RTID:%.*]] = llvm.and [[threadId]], [[c255]]
-    // GFX942: [[laneId:%.*]] = llvm.urem [[RTID]], [[c64]]
-    // GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
-
-    // GFX942: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]]
-    // GFX942: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]]
-
-    // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]]
-    // GFX942: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
-
-    // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
-    // GFX942: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
-
-    // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]]
-    // GFX942: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
-
-    // GFX942: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
-    // GFX942: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
-
-    // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
-    // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // GFX942: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]]
-    // GFX942: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
-    // GFX942: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]]
-
-    // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
-    // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
-    // GFX942: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
-    // GFX942: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]]
-
-    // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
-    // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
-    // GFX942: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
-    // GFX942: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]]
-
-    // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
-    // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
-    // GFX942: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]]
-    // GFX942: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
-    // GFX942: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]]
-
-    // Input (8 values): (vec0, vec1)
-    // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
-    //                  resVec0      resVec1
-    //   lanes 0-15:  (vec0      , vec0 >> 16)   (mask0=1, mask1=1)
-    //   lanes 16-31: (vec0 >> 16, vec0 >> 32)   (mask0=1, mask1=0)
-    //   lanes 32-47: (vec1 >> 32, vec1 >> 48)   (mask0=0, mask1=1)
-    //   lanes 48-63: (vec1 >> 48, vec1      )   (mask0=0, mask1=0)
-
-    // GFX942-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8>
-    // GFX942-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8>
-    // GFX942: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
-
-    // GFX942-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8>
-    // GFX942-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8>
-    // GFX942: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
-
-    // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
-    // GFX942: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
-    // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
-    // GFX942: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
-
-    // GFX942: llvm.insertvalue [[resVal3]], {{.*}}[3]
-    // GFX942: llvm.insertvalue [[resVal7]], {{.*}}[7]
-
-    // GFX942: llvm.return
-    %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
-    tt.return
-  }
-}
-
-// -----
-
-#mfma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
-#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
-
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
   // GFX942-LABEL: mfma_dot_cvt_bf8_mfma16
   tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) {

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 155 deletions
@@ -13,159 +13,6 @@ using ::triton::gpu::LinearEncodingAttr;
 
 namespace {
 
-struct ConvertLayoutOpMFMAToDotOpConversion
-    : public ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
-public:
-  explicit ConvertLayoutOpMFMAToDotOpConversion(
-      LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
-      PatternBenefit benefit)
-      : ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(typeConverter,
-                                                             benefit),
-        targetInfo(targetInfo) {}
-
-  LogicalResult
-  matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto srcType = cast<RankedTensorType>(op.getSrc().getType());
-    auto dstType = cast<RankedTensorType>(op.getType());
-
-    if (!matchMFMAAndDotOperandShuffleCase(srcType, dstType))
-      return failure();
-
-    auto loc = op.getLoc();
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-    SmallVector<Value> inVals =
-        unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    if (inVals.empty() || inVals.size() % 8 != 0)
-      return failure();
-
-    auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcType.getEncoding());
-    assert((mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32) &&
-           "Expected MFMA size 16 or 32");
-    assert(triton::gpu::lookupThreadsPerWarp(rewriter) == 64 &&
-           "Expected warp size 64 for MFMA");
-
-    auto elemTy = int_ty(8);
-    auto vecTy = vec_ty(elemTy, 4);
-
-    Value c16 = b.i32_val(16);
-    Value c32 = b.i32_val(32);
-    Value c48 = b.i32_val(48);
-    Value c64 = b.i32_val(64);
-
-    Value threadId = getThreadId(rewriter, loc);
-    Value laneId = b.urem(threadId, c64);
-
-    Value mask0 = b.icmp_slt(laneId, c32);
-    Value mask1 = b.icmp_slt(b.urem(laneId, c32), c16);
-
-    Value addrShift16 = b.urem(b.add(laneId, c16), c64);
-    Value addrShift32 = b.urem(b.add(laneId, c32), c64);
-    Value addrShift48 = b.urem(b.add(laneId, c48), c64);
-
-    SmallVector<Value> outVals;
-    for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) {
-      Value vec0 = b.undef(vecTy);
-      for (size_t vIdx = 0; vIdx < 4; ++vIdx) {
-        vec0 = b.insert_element(vecTy, vec0, inVals[startIdx + vIdx],
-                                b.i32_val(vIdx));
-      }
-      Value vec1 = b.undef(vecTy);
-      for (size_t vIdx = 0; vIdx < 4; ++vIdx) {
-        vec1 = b.insert_element(vecTy, vec1, inVals[startIdx + vIdx + 4],
-                                b.i32_val(vIdx));
-      }
-
-      Value resVec0, resVec1;
-      if (mfmaLayout.getMDim() == 32) {
-        /*
-        Using wave shuffle to convert layouts (32x32x16 case):
-        1) Input MMA layout (32x32, fp8, 16 values):
-         _____________________________________________________________
-        |(t0 v0 v1 v2 v3) (t32 v0 v1 v2 v3) ... (t32 v12 v13 v14 v15)|
-        | ...                 ...                 ...                |
-        |(t31 v0 v1 v2 v3) (t63 v0 v1 v2 v3) ... (t63 v12 v13 v14 v15)|
-        |_____________________________________________________________|
-
-        2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each):
-         ____________________________________________________________  ___
-        |(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7)  ||
-        | ...                          ...                           ||...
-        |(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) ||
-        |____________________________________________________________||___
-        */
-
-        Value shflVec0 = b.bitcast(
-            targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec0, int_ty(32)),
-                                  addrShift32),
-            vecTy);
-        Value shflVec1 = b.bitcast(
-            targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec1, int_ty(32)),
-                                  addrShift32),
-            vecTy);
-
-        resVec0 = b.select(mask0, vec0, shflVec1);
-        resVec1 = b.select(mask0, shflVec0, vec1);
-      } else if (mfmaLayout.getMDim() == 16) {
-        /*
-        16x16x32 case:
-        1) Input MMA layout (two 16x16, fp8, 4 values each):
-         _________________________________________________________  ___________
-        |(t0 v0 v1 v2 v3) (t16 v0 v1 v2 v3) ... (t48 v0 v1 v2 v3)||(t0 v4 ...
-        | ...                 ...                 ...            || ...
-        |(t15 v0 v1 v2 v3) (t31 v0 v1 v2 v3) ... (t63 v0 v1 v2 v3)||(t15 v4 ...
-        |_________________________________________________________||___________
-
-        2) Output Dot operand layout (16x32 tile, fp8, 8 values):
-         ________________________________________________________________
-        |(t0 v0 v1 v2 v3 v4 v5 v6 v7) ... (t48 v0 v1 v2 v3 v4 v5 v6 v7)  |
-        | ...                          ...                               |
-        |(t15 v0 v1 v2 v3 v4 v5 v6 v7) ... (t63 v0 v1 v2 v3 v4 v5 v6 v7) |
-        |________________________________________________________________|
-        */
-
-        Value shflVec0_16 = b.bitcast(
-            targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec0, int_ty(32)),
-                                  addrShift16),
-            vecTy);
-        Value shflVec0_32 = b.bitcast(
-            targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec0, int_ty(32)),
-                                  addrShift32),
-            vecTy);
-        Value shflVec1_32 = b.bitcast(
-            targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec1, int_ty(32)),
-                                  addrShift32),
-            vecTy);
-        Value shflVec1_48 = b.bitcast(
-            targetInfo.shuffleIdx(rewriter, loc, b.bitcast(vec1, int_ty(32)),
-                                  addrShift48),
-            vecTy);
-
-        resVec0 = b.select(mask0, b.select(mask1, vec0, shflVec0_16),
-                           b.select(mask1, shflVec1_32, shflVec1_48));
-        resVec1 = b.select(mask0, b.select(mask1, shflVec0_16, shflVec0_32),
-                           b.select(mask1, shflVec1_48, vec1));
-      }
-
-      for (size_t vIdx = 0; vIdx < 4; ++vIdx) {
-        outVals.push_back(b.extract_element(elemTy, resVec0, b.i32_val(vIdx)));
-      }
-      for (size_t vIdx = 0; vIdx < 4; ++vIdx) {
-        outVals.push_back(b.extract_element(elemTy, resVec1, b.i32_val(vIdx)));
-      }
-    }
-
-    Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
-                                  op.getType());
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-
-protected:
-  const TargetInfoBase &targetInfo;
-};
-
 // Match MFMA->Linear Layout conversion
 static bool matchMFMAAndLinearLayoutCase(RankedTensorType srcTy,
                                          RankedTensorType dstTy) {
@@ -338,8 +185,6 @@ struct ConvertLayoutForcedSwizzling
 void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ConvertLayoutOpMFMAToDotOpConversion>(typeConverter, targetInfo,
-                                                     benefit);
   patterns.add<ConvertLayoutOpMFMAToLinearConversion>(typeConverter, targetInfo,
                                                        benefit);
   patterns.add<ConvertLayoutForcedPadding>(typeConverter, targetInfo, benefit);
