Skip to content

Commit 8386213

Browse files
[Backend] Remove convertMMAV3To8BitsDotOperand (#7574)
We remove the custom layout conversion lowering for the MMA v3 8-bit case which used warp shuffles as it's now handled by the general pathway #7558. <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --------- Co-authored-by: apgoucher <[email protected]>
1 parent f0975f9 commit 8386213

File tree

5 files changed

+124
-127
lines changed

5 files changed

+124
-127
lines changed

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,6 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
254254

255255
bool atomicNeedsSharedMemory(Value result);
256256

257-
// Return true if the src and dst layout match.
258-
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
259-
RankedTensorType dstTy);
260-
261257
// Check if MFMA layout can be converted to the dot operand
262258
// layout using warp shuffle.
263259
bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,

lib/Analysis/Utility.cpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -713,24 +713,6 @@ bool supportMMA(Value value, int version) {
713713
(elemTy.isInteger(8) && version >= 2);
714714
}
715715

716-
// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
717-
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
718-
RankedTensorType dstTy) {
719-
auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
720-
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
721-
if (!mmaLayout || !dotOperandLayout) {
722-
return false;
723-
}
724-
int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth();
725-
auto parentTy = srcTy.cloneWithEncoding(dotOperandLayout.getParent());
726-
auto ans = mmaLayout.getVersionMajor() == 3 &&
727-
dotOperandLayout.getOpIdx() == 0 &&
728-
mmaLayout.getWarpsPerCTA()[1] == 1 &&
729-
!cvtNeedsSharedMemory(parentTy, srcTy) && elementTypeSize == 8 &&
730-
dotOperandLayout.getKWidth() == 32 / elementTypeSize;
731-
return ans;
732-
}
733-
734716
bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
735717
RankedTensorType dstTy) {
736718
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
@@ -810,7 +792,6 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
810792
// they're fully subsumed by the linear-layout checks.
811793
return !cvtReordersRegisters(srcTy, dstTy) &&
812794
!cvtNeedsWarpShuffle(srcTy, dstTy) &&
813-
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
814795
// to be removed when generalized warp shuffle conversions
815796
// are ready:
816797
!matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,9 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
216216
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>
217217
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
218218
// CHECK-LABEL: cvt_mma_to_dot_fp8
219-
// CHECK: nvvm.prmt
220-
// CHECK: nvvm.prmt
221-
// CHECK: nvvm.shfl.sync
222-
// CHECK: nvvm.shfl.sync
223-
// CHECK: nvvm.prmt
224-
// CHECK: nvvm.prmt
219+
// CHECK-COUNT-16: llvm.select
220+
// CHECK-COUNT-16: nvvm.shfl.sync
221+
// CHECK-COUNT-16: llvm.select
225222
tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) {
226223
%opA = ttg.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
227224
tt.return
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=83' --convert-nv-gpu-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 -S | llc -mtriple nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx83 | FileCheck --dump-input-context=20 %s
2+
3+
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
4+
#dot_op = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth=4}>
5+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
6+
// CHECK-LABEL: cvt_mma_to_dot_fp8
7+
tt.func @cvt_mma_to_dot_fp8(%ptr : !llvm.ptr, %arg0: tensor<128x64xf8E5M2, #mma>) {
8+
9+
// As there are 64 elements per lane, we don't use variables to track them.
10+
11+
// CHECK-COUNT-64: ld.param.b8
12+
13+
// Intra-warp layout conversions can be viewed as a permutation of register
14+
// and lane basis vectors. This can be read off from the linear layouts:
15+
//
16+
// #mma: register: [[0,1], [8,0], [0,8], [0,16], [0,32], [64,0]]
17+
// lane: [[0,2], [0,4], [1,0], [2,0], [4,0]]
18+
// warp: [[16,0], [32,0]]
19+
//
20+
// #dot_op: register: [[0,1], [0,2], [8,0], [0,16], [0,32], [64,0]]
21+
// lane: [[0,4], [0,8], [1,0], [2,0], [4,0]]
22+
// warp: [[16,0], [32,0]]
23+
//
24+
// The layout conversion is described by the permutation (r1 r2 l1 l0),
25+
// which factors as (r1 l1)(l0 l1)(r1 r2).
26+
//
27+
// Register basis vectors correspond to the bits of the indices of the 64
28+
// separate registers which hold the original elements. Since we end up
29+
// packing 4 elements per register, we end up with only 16 registers in
30+
// total before shuffling. The `transferWithinWarp` implementation handles
31+
// register packing by ensuring that elements are packed together only if
32+
// under the layout conversion, they end up in the same destination lane.
33+
// To do this, it rearranges the 64 registers so that it can pack 4
34+
// consecutive elements at a time according to their new register index.
35+
//
36+
// The transposition (r1 l1) above indicates that initially, elements with
37+
// register indices whose r1 bit is on are to be moved to new lanes. We thus
38+
// need to rearrange the registers. The algorithm chooses the next register
39+
// bit > 1 which is not used in a mixed transposition. In this case,
40+
// that bit is r2. Algebraically, this corresponds to conjugating the
41+
// permutation with (r1 r2). This produces (r1 r2)(r2 l1)(l0 l1). The new
42+
// (r1 r2) at the end rearranges elements after unpacking, and only
43+
// (r2 l1)(l0 l1) matters for tracking the movement of the packed registers.
44+
// From the point of view of the packed registers, the symbol `r2` now
45+
// corresponds to the 0th bit of a (packed) register's index.
46+
//
47+
// The transposition (r2 l1) is a bit swap which is implemented in-place as:
48+
// 1. r2 ^= l1
49+
// 2. l1 ^= r2
50+
// 3. r2 ^= l1.
51+
// The algorithm conjugates (l0 l1) through the first two stages to produce:
52+
// 1. r2 ^= l0
53+
// 2a. l0 ^= r2
54+
// 2b. (l0 l1)
55+
// 3. r2 ^= l1.
56+
// The first step is to get the value of l0.
57+
58+
// CHECK: mov.u32 [[TID:%.*]], %tid.x;
59+
// CHECK: and.b32 [[L0_VAL:%.*]], [[TID]], 1;
60+
// CHECK: setp.eq.s32 [[L0_OFF:%.*]], [[L0_VAL]], 0;
61+
62+
// This is used to perform 16 independent selects in stage 1.
63+
64+
// CHECK-COUNT-16: selp.b32 {{.*}}, {{.*}}, [[L0_OFF]];
65+
66+
// Next, we apply (l0 l1) to the lane id to get the base source lane for
67+
// the index shuffles. This is step 2b above, but since we must specify
68+
// the *source* lane for a warp-shuffle, it gets applied first in practice:
69+
//
70+
// dstLane = ((l0 l1) \circ (l0 ^= r2))(srcLane)
71+
// srcLane = ((l0 ^= r2) \circ (l0 l1))(dstLane)
72+
//
73+
// To apply (l0 l1), we use a compile-time mask to collect the fixed bits,
74+
// and then we OR it with the shifted l0 and l1 values.
75+
76+
// CHECK-DAG: and.b32 [[LANEID_FIXED_BITS:%.*]], [[TID]], 28;
77+
// CHECK-DAG: shl.b32 [[L0_TEMP:%.*]], [[L0_VAL]], 1;
78+
// CHECK-DAG: or.b32 [[LANEID_PART_PERM:%.*]], [[L0_TEMP]], [[LANEID_FIXED_BITS]];
79+
// CHECK-DAG: bfe.u32 [[L1_TEMP:%.*]], [[TID]], 1, 1;
80+
// CHECK-DAG: or.b32 [[LANEID_PERM:%.*]], [[LANEID_PART_PERM]], [[L1_TEMP]];
81+
82+
// The index shuffles have source lane dependent on the value of the r2 bit.
83+
// Half of them use `LANEID_PERM` while the other half use `LANEID_PERM`
84+
// with the l0 bit flipped (step 2a).
85+
86+
// CHECK-DAG: xor.b32 [[LANEID_PERM_F:%.*]], [[LANEID_PERM]], 1;
87+
88+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1;
89+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1;
90+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1;
91+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1;
92+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1;
93+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1;
94+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1;
95+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM]], 31, -1;
96+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1;
97+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1;
98+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1;
99+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1;
100+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1;
101+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1;
102+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1;
103+
// CHECK-DAG: shfl.sync.idx.b32 {{.*}}, [[LANEID_PERM_F]], 31, -1;
104+
105+
// Finally, the last set of selects are performed, using the value of l1 as
106+
// the predicate (step 3).
107+
108+
// CHECK-DAG: and.b32 [[L1_VAL:%.*]], [[TID]], 2;
109+
// CHECK-DAG: setp.eq.s32 [[L1_OFF:%.*]], [[L1_VAL]], 0;
110+
// CHECK-COUNT-16: selp.b32 {{.*}}, {{.*}}, [[L1_OFF]];
111+
112+
// CHECK-COUNT-64: bfe.u32
113+
// CHECK-COUNT-64: st.volatile.global.b8
114+
115+
%0 = ttg.convert_layout %arg0 : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #dot_op>
116+
%1 = builtin.unrealized_conversion_cast %0 : tensor<128x64xf8E5M2, #dot_op> to !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
117+
llvm.store volatile %1, %ptr : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>, !llvm.ptr
118+
119+
tt.return
120+
}
121+
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ struct ConvertLayoutOpConversion
3838
if (shouldUseDistSmem(srcLayout, dstLayout))
3939
return lowerDistToDistWithDistSmem(op, adaptor, rewriter, targetInfo);
4040
}
41-
if (isa<NvidiaMmaEncodingAttr>(srcLayout) &&
42-
isa<DotOperandEncodingAttr>(dstLayout)) {
43-
return lowerMmaToDotOperand(op, adaptor, rewriter);
44-
}
4541

4642
return failure();
4743
}
@@ -136,100 +132,6 @@ struct ConvertLayoutOpConversion
136132
return success();
137133
}
138134

139-
// Convert from accumulator MMA layout to 8bit dot operand layout.
140-
// The conversion logic is taken from:
141-
// https://github.com/ColfaxResearch/cutlass-kernels/blob/a9de6446c1c0415c926025cea284210c799b11f8/src/fmha-pipeline/reg2reg.h#L45
142-
void
143-
convertMMAV3To8BitsDotOperand(triton::gpu::ConvertLayoutOp op,
144-
OpAdaptor adaptor,
145-
ConversionPatternRewriter &rewriter) const {
146-
auto loc = op.getLoc();
147-
auto b = TritonLLVMOpBuilder(loc, rewriter);
148-
auto dstTy = op.getType();
149-
auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
150-
SmallVector<Value> retVals;
151-
for (int i = 0; i < vals.size(); i += 8) {
152-
Value upper = b.undef(vec_ty(i8_ty, 4));
153-
for (int j = 0; j < 4; j++) {
154-
upper = b.insert_element(vec_ty(i8_ty, 4), upper, vals[i + j],
155-
b.i32_val(j));
156-
}
157-
upper = b.bitcast(upper, i32_ty);
158-
Value lower = b.undef(vec_ty(i8_ty, 4));
159-
for (int j = 0; j < 4; j++) {
160-
lower = b.insert_element(vec_ty(i8_ty, 4), lower, vals[i + 4 + j],
161-
b.i32_val(j));
162-
}
163-
lower = b.bitcast(lower, i32_ty);
164-
165-
Value threadIdMod4 = b.urem(getThreadId(rewriter, loc), b.i32_val(4));
166-
Value cnd = b.or_(b.icmp_eq(threadIdMod4, b.i32_val(0)),
167-
b.icmp_eq(threadIdMod4, b.i32_val(3)));
168-
Value selectorEx0 = b.select(cnd, b.i32_val(0x3210), b.i32_val(0x7654));
169-
Value selectorEx1 = b.select(cnd, b.i32_val(0x7654), b.i32_val(0x3210));
170-
Value selectorEx4 = b.select(cnd, b.i32_val(0x5410), b.i32_val(0x1054));
171-
Value selectorEx5 = b.select(cnd, b.i32_val(0x7632), b.i32_val(0x3276));
172-
173-
Value isOne = b.icmp_eq(threadIdMod4, b.i32_val(1));
174-
Value isTwo = b.icmp_eq(threadIdMod4, b.i32_val(2));
175-
Value isThree = b.icmp_eq(threadIdMod4, b.i32_val(3));
176-
Value upperIdx = b.i32_val(0);
177-
upperIdx = b.select(isOne, b.i32_val(3), upperIdx);
178-
upperIdx = b.select(isTwo, b.i32_val(1), upperIdx);
179-
upperIdx = b.select(isThree, b.i32_val(2), upperIdx);
180-
181-
Value lowerIdx = b.i32_val(1);
182-
lowerIdx = b.select(isOne, b.i32_val(2), lowerIdx);
183-
lowerIdx = b.select(isTwo, b.i32_val(0), lowerIdx);
184-
lowerIdx = b.select(isThree, b.i32_val(3), lowerIdx);
185-
186-
Value upper0 =
187-
LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx0);
188-
Value lower0 =
189-
LLVM::NVIDIA::permute(loc, rewriter, upper, lower, selectorEx1);
190-
Value mask = b.i32_val(0xFFFFFFFF);
191-
// Set clamp to shuffle only within 4 lanes.
192-
Value clamp = b.i32_val(0x1C1F);
193-
upper0 =
194-
rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, upper0, upperIdx,
195-
clamp, NVVM::ShflKind::idx, UnitAttr());
196-
lower0 =
197-
rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, lower0, lowerIdx,
198-
clamp, NVVM::ShflKind::idx, UnitAttr());
199-
Value upper1 =
200-
LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx4);
201-
Value vecVal = b.bitcast(upper1, vec_ty(i8_ty, 4));
202-
for (int i = 0; i < 4; i++) {
203-
retVals.push_back(b.extract_element(i8_ty, vecVal, b.i32_val(i)));
204-
}
205-
Value lower1 =
206-
LLVM::NVIDIA::permute(loc, rewriter, upper0, lower0, selectorEx5);
207-
vecVal = b.bitcast(lower1, vec_ty(i8_ty, 4));
208-
for (int i = 0; i < 4; i++) {
209-
retVals.push_back(b.extract_element(i8_ty, vecVal, b.i32_val(i)));
210-
}
211-
}
212-
Value result =
213-
packLLElements(loc, getTypeConverter(), retVals, rewriter, dstTy);
214-
rewriter.replaceOp(op, result);
215-
}
216-
217-
// mma -> dot_operand
218-
LogicalResult
219-
lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
220-
ConversionPatternRewriter &rewriter) const {
221-
auto loc = op.getLoc();
222-
auto srcTy = op.getSrc().getType();
223-
auto dstTy = op.getType();
224-
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
225-
assert(srcTy.getElementType().getIntOrFloatBitWidth() == 8 &&
226-
"Unsupported type size.");
227-
convertMMAV3To8BitsDotOperand(op, adaptor, rewriter);
228-
return success();
229-
}
230-
return failure();
231-
}
232-
233135
private:
234136
const NVIDIA::TargetInfo &targetInfo;
235137
};

0 commit comments

Comments
 (0)