Skip to content

Commit 41ecd1c

Browse files
authored
[AMD] Reduce instruction count to upcast mxfp4 (triton-lang#5651)
This PR improves the logic to upcast mxfp4 to bf/fp16 values. Previously, we process 4 mxfp4 values at a time with about 23 instructions. Therefore, it takes **46** instructions to upcast 8 mxfp4 elements. This PR processes 8 mxfp4 values at a time with **20** instructions.
1 parent 04de400 commit 41ecd1c

File tree

1 file changed

+112
-70
lines changed

1 file changed

+112
-70
lines changed

third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 112 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -21,42 +21,9 @@ using namespace mlir::triton::gpu;
2121

2222
namespace {
2323

24-
// Returns (EM, S) selectors to the llvm.amdgcn.perm intrinsic for selecting
25-
// resultant bf16/fp16 bytes in the lookup table.
26-
std::pair<Value, Value> composePermuteSelectors(Location loc,
27-
RewriterBase &rewriter,
28-
Value val10, Value val32) {
29-
// Each input value packs two mxfp4 values. First extract each mxfp4 value's
30-
// EM and S bits. In order to form the selector for llvm.amdgcn.perm
31-
// instruction, we need to shuffle them into a 4xu8 manner.
32-
33-
// 0xX[S.EE.M] -> 0x0000000[0EEM]
34-
Value v0EM = zext(i32_ty, and_(val10, i8_val(0x07)));
35-
// 0xX[S.EE.M] -> 0x0000000[000S]
36-
Value v0S = lshr(zext(i32_ty, and_(val10, i8_val(0x08))), i32_val(3));
37-
// 0x[S.EE.M]X -> 0x00000[0EEM]00
38-
Value v1EM = shl(zext(i32_ty, and_(val10, i8_val(0x70))), i32_val(4));
39-
// 0x[S.EE.M]X -> 0x00000[000S]00
40-
Value v1S = shl(zext(i32_ty, and_(val10, i8_val(0x80))), i32_val(4 - 3));
41-
42-
// 0xX[S.EE.M] -> 0x000[0EEM]0000
43-
Value v2EM = shl(zext(i32_ty, and_(val32, i8_val(0x07))), i32_val(16));
44-
// 0xX[S.EE.M] -> 0x000[000S]0000
45-
Value v2S = shl(zext(i32_ty, and_(val32, i8_val(0x08))), i32_val(16 - 3));
46-
// 0x[S.EE.M]X -> 0x0[0EEM]000000
47-
Value v3EM = shl(zext(i32_ty, and_(val32, i8_val(0x70))), i32_val(20));
48-
// 0x[S.EE.M]X -> 0x0[000S]000000
49-
Value v3S = shl(zext(i32_ty, and_(val32, i8_val(0x80))), i32_val(20 - 3));
50-
51-
Value selectorEM = or_(v3EM, or_(v2EM, or_(v1EM, v0EM)));
52-
Value selectorS = or_(v3S, or_(v2S, or_(v1S, v0S)));
53-
return {selectorEM, selectorS};
54-
}
55-
56-
SmallVector<Value, 2> upcast4xMxfp4(RewriterBase &rewriter,
24+
SmallVector<Value, 4> upcast8xMxfp4(RewriterBase &rewriter,
5725
UpcastMXFPOp upcastOp, bool tofp16,
58-
ArrayRef<Value> inputs) {
59-
assert(inputs.size() == 2);
26+
Value packedVec) {
6027
Location loc = upcastOp.getLoc();
6128

6229
// MXFP4 has 4 bits, S.EE.M, for Sign, Exponent, and Mantissa respectively.
@@ -68,9 +35,6 @@ SmallVector<Value, 2> upcast4xMxfp4(RewriterBase &rewriter,
6835
// resultant bf16/fp16 bit patterns separately. For the byte containing S, we
6936
// also need to handle the S and E bits separately.
7037

71-
auto [selectorEM, selectorS] =
72-
composePermuteSelectors(loc, rewriter, inputs[0], inputs[1]);
73-
7438
// FP4 has 4 bits: S.EE.M. Bf16/fp16 bit patterns for positive values:
7539
//
7640
// FP4 | BF16 | FP16 | Value
@@ -90,55 +54,133 @@ SmallVector<Value, 2> upcast4xMxfp4(RewriterBase &rewriter,
9054
// Encode Byte #1 (EM, non-S part) for BF16/FP16 in a LUT.
9155
Value resB1LutLoNoS = tofp16 ? i32_val(0x3e3c3800) : i32_val(0x3f3f3f00);
9256
Value resB1LutHiNoS = tofp16 ? i32_val(0x46444240) : i32_val(0x40404040);
93-
// Encode Byte #1 (S part) for BF16/FP16 in a LUT.
94-
Value resB1LutLoS = i32_val(0x8000);
95-
Value resB1LutHiS = i32_val(0);
9657

9758
Type i32Ty = rewriter.getI32Type();
9859
auto permU32FnTy = LLVM::LLVMFunctionType::get(i32Ty, {i32Ty, i32Ty, i32Ty});
9960
LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp(
10061
rewriter, upcastOp, "llvm.amdgcn.perm", permU32FnTy);
10162

102-
// Select Byte #0 for all 4 mxfp4 values. It's always 0 if upcasting to fp16.
103-
Value resB0 = i32_val(0);
63+
// Start with 8 mxfp4 elements in a single i32 register
64+
// | e7e6 | e5e4 | e3e2 | e1e0 |
65+
Value input = bitcast(packedVec, i32Ty);
66+
67+
// Step 1: extract EM bits for elements 0,2,4,6 and 1,3,5,7 respectively.
68+
// e2m1_6420_idx = | 0[0e6EM] | 0[0e4EM] | 0[0e2EM] | 0[0e0EM] |
69+
Value e2m1_6420_idx = and_(input, i32_val(0x07070707));
70+
// e2m1_7531_idx = | [0e7EM]0 | [0e5EM]0 | [0e3EM]0 | [0e1EM]0 |
71+
Value e2m1_7531_idx = and_(input, i32_val(0x70707070));
72+
// e2m1_7531_idx = | 0[0e7EM] | 0[0e5EM] | 0[0e3EM] | 0[0e1EM] |
73+
e2m1_7531_idx = lshr(e2m1_7531_idx, i32_val(4));
74+
75+
// Step 2: extract S bit for elements 0,2,4,6 and 1,3,5,7
76+
// s_6420 = | 0[e6S000] | 0[e4S000] | 0[e2S000] | 0[e0S000] |
77+
Value s_6420 = and_(input, i32_val(0x08080808));
78+
// s_6420 = | [e6S000]0 | [e4S000]0 | [e2S000]0 | [e0S000]0 |
79+
s_6420 = shl(s_6420, i32_val(4));
80+
// s_7531 = | [e7S000]0 | [e5S000]0 | [e3S000]0 | [e1S000]0 |
81+
Value s_7531 = and_(input, i32_val(0x80808080));
82+
83+
// Step 3: Upcast elements 0,2,4,6 to 4 16-bit elements
84+
// Select Byte #0. It's always 0 if upcasting to fp16.
85+
// resB0_6420 = | e6B0 | e4B0 | e2B0 | e0B0 |
86+
Value resB0_6420 = i32_val(0);
87+
if (!tofp16) {
88+
resB0_6420 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
89+
{resB0LutHi, resB0LutLo, e2m1_6420_idx})
90+
.getResult();
91+
}
92+
// Select Byte #1
93+
Value resB1NoS_6420 =
94+
LLVM::createLLVMCallOp(rewriter, loc, funcOp,
95+
{resB1LutHiNoS, resB1LutLoNoS, e2m1_6420_idx})
96+
.getResult();
97+
// resB1_6420 = | e6B1 | e4B1 | e2B1 | e0B1 |
98+
Value resB1_6420 = or_(resB1NoS_6420, s_6420);
99+
// Construct 16-bit values of e0 and e2
100+
// res_20 = | e2B1 | e2B0 | e0B1 | e0B0 | = | e2_f16 | e0_f16 |
101+
Value res_20 =
102+
LLVM::createLLVMCallOp(rewriter, loc, funcOp,
103+
{resB1_6420, resB0_6420, i32_val(0x05010400)})
104+
.getResult();
105+
// Construct 16-bit values of e4 and e6
106+
// res_64 = | e6B1 | e6B0 | e4B1 | e4B0 | = | e6_f16 | e4_f16 |
107+
Value res_64 =
108+
LLVM::createLLVMCallOp(rewriter, loc, funcOp,
109+
{resB1_6420, resB0_6420, i32_val(0x07030602)})
110+
.getResult();
111+
112+
// Step 4: Upcast elements 1,3,5,7 to 4 16-bit elements
113+
// This is a copy of step 3 on different group of elements
114+
// Select Byte #0. It's always 0 if upcasting to fp16.
115+
// resB0_7531 = | e7B0 | e5B0 | e3B0 | e1B0 |
116+
Value resB0_7531 = i32_val(0);
104117
if (!tofp16) {
105-
resB0 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
106-
{resB0LutHi, resB0LutLo, selectorEM})
107-
.getResult();
118+
resB0_7531 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
119+
{resB0LutHi, resB0LutLo, e2m1_7531_idx})
120+
.getResult();
108121
}
109-
// Select Byte #1 for all 4 mxfp4 values.
110-
auto resB1NoS = LLVM::createLLVMCallOp(
111-
rewriter, loc, funcOp, {resB1LutHiNoS, resB1LutLoNoS, selectorEM});
112-
auto resB1S = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
113-
{resB1LutHiS, resB1LutLoS, selectorS});
114-
Value restB1 = or_(resB1NoS.getResult(), resB1S.getResult());
115-
116-
// Extract resultant bf16/fp16 values #0 and #1.
117-
// #0 would use selector 0x00/0x04 to pick from B0/B1.
118-
// #1 would use selector 0x01/0x05 to pick from B0/B1.
119-
auto res10 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
120-
{restB1, resB0, i32_val(0x05010400)});
121-
// Extract resultant bf16/fp16 values #2 and #3.
122-
// #2 would use selector 0x02/0x06 to pick from B0/B1.
123-
// #3 would use selector 0x03/0x07 to pick from B0/B1.
124-
auto res32 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
125-
{restB1, resB0, i32_val(0x07030602)});
126-
127-
return {res10.getResult(), res32.getResult()};
122+
// Select Byte #1
123+
Value resB1NoS_7531 =
124+
LLVM::createLLVMCallOp(rewriter, loc, funcOp,
125+
{resB1LutHiNoS, resB1LutLoNoS, e2m1_7531_idx})
126+
.getResult();
127+
// resB1_7531 = | e7B1 | e5B1 | e3B1 | e1B1 |
128+
Value resB1_7531 = or_(resB1NoS_7531, s_7531);
129+
// Construct 16-bit values of e1 and e3
130+
// res_31 = | e3B1 | e3B0 | e1B1 | e1B0 | = | e3_f16 | e1_f16 |
131+
Value res_31 =
132+
LLVM::createLLVMCallOp(rewriter, loc, funcOp,
133+
{resB1_7531, resB0_7531, i32_val(0x05010400)})
134+
.getResult();
135+
// Construct 16-bit values of e5 and e7
136+
// res_75 = | e7B1 | e7B0 | e5B1 | e5B0 | = | e7_f16 | e5_f16 |
137+
Value res_75 =
138+
LLVM::createLLVMCallOp(rewriter, loc, funcOp,
139+
{resB1_7531, resB0_7531, i32_val(0x07030602)})
140+
.getResult();
141+
142+
// Step 5: Reorder 16-bit elements to be 0,1,2,3,4,5,6,7
143+
// res_10 = | e1_f16 | e0_f16 |
144+
Value res_10 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
145+
{res_31, res_20, i32_val(0x05040100)})
146+
.getResult();
147+
// res_32 = | e3_f16 | e2_f16 |
148+
Value res_32 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
149+
{res_31, res_20, i32_val(0x07060302)})
150+
.getResult();
151+
// res_54 = | e5_f16 | e4_f16 |
152+
Value res_54 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
153+
{res_75, res_64, i32_val(0x05040100)})
154+
.getResult();
155+
// res_76 = | e7_f16 | e6_f16 |
156+
Value res_76 = LLVM::createLLVMCallOp(rewriter, loc, funcOp,
157+
{res_75, res_64, i32_val(0x07060302)})
158+
.getResult();
159+
160+
return {res_10, res_32, res_54, res_76};
128161
}
129162

130163
SmallVector<Value> upcastMxfp4(RewriterBase &rewriter, UpcastMXFPOp upcastOp,
131164
bool toFp16, ArrayRef<Value> values) {
132-
assert(values.size() % 2 == 0);
165+
assert(values.size() % 4 == 0);
133166
Location loc = upcastOp.getLoc();
134167

135168
SmallVector<Value> results;
136169
results.reserve(values.size() * 2);
137170
Type elemType = toFp16 ? f16_ty : bf16_ty;
138-
for (int i = 0; i < values.size(); i += 2) {
139-
SmallVector<Value, 2> v4i32 =
140-
upcast4xMxfp4(rewriter, upcastOp, toFp16, values.slice(i, 2));
141-
for (int j = 0; j < 2; j++) {
171+
for (int i = 0; i < values.size(); i += 4) {
172+
Value v0 = values[i];
173+
Value v1 = values[i + 1];
174+
Value v2 = values[i + 2];
175+
Value v3 = values[i + 3];
176+
Value packedVec = undef(vec_ty(i8_ty, 4));
177+
packedVec = insert_element(packedVec, v0, i32_val(0));
178+
packedVec = insert_element(packedVec, v1, i32_val(1));
179+
packedVec = insert_element(packedVec, v2, i32_val(2));
180+
packedVec = insert_element(packedVec, v3, i32_val(3));
181+
SmallVector<Value, 4> v4i32 =
182+
upcast8xMxfp4(rewriter, upcastOp, toFp16, packedVec);
183+
for (int j = 0; j < 4; j++) {
142184
Value elements = bitcast(v4i32[j], vec_ty(elemType, 2));
143185
results.push_back(extract_element(elements, i32_val(0)));
144186
results.push_back(extract_element(elements, i32_val(1)));

0 commit comments

Comments
 (0)