@@ -21,42 +21,9 @@ using namespace mlir::triton::gpu;
2121
2222namespace {
2323
24- // Returns (EM, S) selectors to the llvm.amdgcn.perm intrinsic for selecting
25- // resultant bf16/fp16 bytes in the lookup table.
26- std::pair<Value, Value> composePermuteSelectors (Location loc,
27- RewriterBase &rewriter,
28- Value val10, Value val32) {
29- // Each input value packs two mxfp4 values. First extract each mxfp4 value's
30- // EM and S bits. In order to form the selector for llvm.amdgcn.perm
31- // instruction, we need to shuffle them into a 4xu8 manner.
32-
33- // 0xX[S.EE.M] -> 0x0000000[0EEM]
34- Value v0EM = zext (i32_ty, and_ (val10, i8_val (0x07 )));
35- // 0xX[S.EE.M] -> 0x0000000[000S]
36- Value v0S = lshr (zext (i32_ty, and_ (val10, i8_val (0x08 ))), i32_val (3 ));
37- // 0x[S.EE.M]X -> 0x00000[0EEM]00
38- Value v1EM = shl (zext (i32_ty, and_ (val10, i8_val (0x70 ))), i32_val (4 ));
39- // 0x[S.EE.M]X -> 0x00000[000S]00
40- Value v1S = shl (zext (i32_ty, and_ (val10, i8_val (0x80 ))), i32_val (4 - 3 ));
41-
42- // 0xX[S.EE.M] -> 0x000[0EEM]0000
43- Value v2EM = shl (zext (i32_ty, and_ (val32, i8_val (0x07 ))), i32_val (16 ));
44- // 0xX[S.EE.M] -> 0x000[000S]0000
45- Value v2S = shl (zext (i32_ty, and_ (val32, i8_val (0x08 ))), i32_val (16 - 3 ));
46- // 0x[S.EE.M]X -> 0x0[0EEM]000000
47- Value v3EM = shl (zext (i32_ty, and_ (val32, i8_val (0x70 ))), i32_val (20 ));
48- // 0x[S.EE.M]X -> 0x0[000S]000000
49- Value v3S = shl (zext (i32_ty, and_ (val32, i8_val (0x80 ))), i32_val (20 - 3 ));
50-
51- Value selectorEM = or_ (v3EM, or_ (v2EM, or_ (v1EM, v0EM)));
52- Value selectorS = or_ (v3S, or_ (v2S, or_ (v1S, v0S)));
53- return {selectorEM, selectorS};
54- }
55-
56- SmallVector<Value, 2 > upcast4xMxfp4 (RewriterBase &rewriter,
24+ SmallVector<Value, 4 > upcast8xMxfp4 (RewriterBase &rewriter,
5725 UpcastMXFPOp upcastOp, bool tofp16,
58- ArrayRef<Value> inputs) {
59- assert (inputs.size () == 2 );
26+ Value packedVec) {
6027 Location loc = upcastOp.getLoc ();
6128
6229 // MXFP4 has 4 bits, S.EE.M, for Sign, Exponent, and Mantissa respectively.
@@ -68,9 +35,6 @@ SmallVector<Value, 2> upcast4xMxfp4(RewriterBase &rewriter,
6835 // resultant bf16/fp16 bit patterns separately. For the byte containing S, we
6936 // also need to handle the S and E bits separately.
7037
71- auto [selectorEM, selectorS] =
72- composePermuteSelectors (loc, rewriter, inputs[0 ], inputs[1 ]);
73-
7438 // FP4 has 4 bits: S.EE.M. Bf16/fp16 bit patterns for positive values:
7539 //
7640 // FP4 | BF16 | FP16 | Value
@@ -90,55 +54,133 @@ SmallVector<Value, 2> upcast4xMxfp4(RewriterBase &rewriter,
9054 // Encode Byte #1 (EM, non-S part) for BF16/FP16 in a LUT.
9155 Value resB1LutLoNoS = tofp16 ? i32_val (0x3e3c3800 ) : i32_val (0x3f3f3f00 );
9256 Value resB1LutHiNoS = tofp16 ? i32_val (0x46444240 ) : i32_val (0x40404040 );
93- // Encode Byte #1 (S part) for BF16/FP16 in a LUT.
94- Value resB1LutLoS = i32_val (0x8000 );
95- Value resB1LutHiS = i32_val (0 );
9657
9758 Type i32Ty = rewriter.getI32Type ();
9859 auto permU32FnTy = LLVM::LLVMFunctionType::get (i32Ty, {i32Ty, i32Ty, i32Ty});
9960 LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp (
10061 rewriter, upcastOp, " llvm.amdgcn.perm" , permU32FnTy);
10162
102- // Select Byte #0 for all 4 mxfp4 values. It's always 0 if upcasting to fp16.
103- Value resB0 = i32_val (0 );
63+ // Start with 8 mxfp4 elements in a single i32 register
64+ // | e7e6 | e5e4 | e3e2 | e1e0 |
65+ Value input = bitcast (packedVec, i32Ty);
66+
67+ // Step 1: extract EM bits for elements 0,2,4,6 and 1,3,5,7 respectively.
68+ // e2m1_6420_idx = | 0[0e6EM] | 0[0e4EM] | 0[0e2EM] | 0[0e0EM] |
69+ Value e2m1_6420_idx = and_ (input, i32_val (0x07070707 ));
70+ // e2m1_7531_idx = | [0e7EM]0 | [0e5EM]0 | [0e3EM]0 | [0e1EM]0 |
71+ Value e2m1_7531_idx = and_ (input, i32_val (0x70707070 ));
72+ // e2m1_7531_idx = | 0[0e7EM] | 0[0e5EM] | 0[0e3EM] | 0[0e1EM] |
73+ e2m1_7531_idx = lshr (e2m1_7531_idx, i32_val (4 ));
74+
75+ // Step 2: extract S bit for elements 0,2,4,6 and 1,3,5,7
76+ // s_6420 = | 0[e6S000] | 0[e4S000] | 0[e2S000] | 0[e0S000] |
77+ Value s_6420 = and_ (input, i32_val (0x08080808 ));
78+ // s_6420 = | [e6S000]0 | [e4S000]0 | [e2S000]0 | [e0S000]0 |
79+ s_6420 = shl (s_6420, i32_val (4 ));
80+ // s_7531 = | [e7S000]0 | [e5S000]0 | [e3S000]0 | [e1S000]0 |
81+ Value s_7531 = and_ (input, i32_val (0x80808080 ));
82+
83+ // Step 3: Upcast elements 0,2,4,6 to 4 16-bit elements
84+ // Select Byte #0. It's always 0 if upcasting to fp16.
85+ // resB0_6420 = | e6B0 | e4B0 | e2B0 | e0B0 |
86+ Value resB0_6420 = i32_val (0 );
87+ if (!tofp16) {
88+ resB0_6420 = LLVM::createLLVMCallOp (rewriter, loc, funcOp,
89+ {resB0LutHi, resB0LutLo, e2m1_6420_idx})
90+ .getResult ();
91+ }
92+ // Select Byte #1
93+ Value resB1NoS_6420 =
94+ LLVM::createLLVMCallOp (rewriter, loc, funcOp,
95+ {resB1LutHiNoS, resB1LutLoNoS, e2m1_6420_idx})
96+ .getResult ();
97+ // resB1_6420 = | e6B1 | e4B1 | e2B1 | e0B1 |
98+ Value resB1_6420 = or_ (resB1NoS_6420, s_6420);
99+ // Construct 16-bit values of e0 and e2
100+ // res_20 = | e2B1 | e2B0 | e0B1 | e0B0 | = | e2_f16 | e0_f16 |
101+ Value res_20 =
102+ LLVM::createLLVMCallOp (rewriter, loc, funcOp,
103+ {resB1_6420, resB0_6420, i32_val (0x05010400 )})
104+ .getResult ();
105+ // Construct 16-bit values of e4 and e6
106+ // res_64 = | e6B1 | e6B0 | e4B1 | e4B0 | = | e6_f16 | e4_f16 |
107+ Value res_64 =
108+ LLVM::createLLVMCallOp (rewriter, loc, funcOp,
109+ {resB1_6420, resB0_6420, i32_val (0x07030602 )})
110+ .getResult ();
111+
112+ // Step 4: Upcast elements 1,3,5,7 to 4 16-bit elements
113+ // This is a copy of step 3 on different group of elements
114+ // Select Byte #0. It's always 0 if upcasting to fp16.
115+ // resB0_7531 = | e7B0 | e5B0 | e3B0 | e1B0 |
116+ Value resB0_7531 = i32_val (0 );
104117 if (!tofp16) {
105- resB0 = LLVM::createLLVMCallOp (rewriter, loc, funcOp,
106- {resB0LutHi, resB0LutLo, selectorEM })
107- .getResult ();
118+ resB0_7531 = LLVM::createLLVMCallOp (rewriter, loc, funcOp,
119+ {resB0LutHi, resB0LutLo, e2m1_7531_idx })
120+ .getResult ();
108121 }
109- // Select Byte #1 for all 4 mxfp4 values.
110- auto resB1NoS = LLVM::createLLVMCallOp (
111- rewriter, loc, funcOp, {resB1LutHiNoS, resB1LutLoNoS, selectorEM});
112- auto resB1S = LLVM::createLLVMCallOp (rewriter, loc, funcOp,
113- {resB1LutHiS, resB1LutLoS, selectorS});
114- Value restB1 = or_ (resB1NoS.getResult (), resB1S.getResult ());
115-
116- // Extract resultant bf16/fp16 values #0 and #1.
117- // #0 would use selector 0x00/0x04 to pick from B0/B1.
118- // #1 would use selector 0x01/0x05 to pick from B0/B1.
119- auto res10 = LLVM::createLLVMCallOp (rewriter, loc, funcOp,
120- {restB1, resB0, i32_val (0x05010400 )});
121- // Extract resultant bf16/fp16 values #2 and #3.
122- // #2 would use selector 0x02/0x06 to pick from B0/B1.
123- // #3 would use selector 0x03/0x07 to pick from B0/B1.
124- auto res32 = LLVM::createLLVMCallOp (rewriter, loc, funcOp,
125- {restB1, resB0, i32_val (0x07030602 )});
126-
127- return {res10.getResult (), res32.getResult ()};
122+ // Select Byte #1
123+ Value resB1NoS_7531 =
124+ LLVM::createLLVMCallOp (rewriter, loc, funcOp,
125+ {resB1LutHiNoS, resB1LutLoNoS, e2m1_7531_idx})
126+ .getResult ();
127+ // resB1_7531 = | e7B1 | e5B1 | e3B1 | e1B1 |
128+ Value resB1_7531 = or_ (resB1NoS_7531, s_7531);
129+ // Construct 16-bit values of e1 and e3
130+ // res_31 = | e3B1 | e3B0 | e1B1 | e1B0 | = | e3_f16 | e1_f16 |
131+ Value res_31 =
132+ LLVM::createLLVMCallOp (rewriter, loc, funcOp,
133+ {resB1_7531, resB0_7531, i32_val (0x05010400 )})
134+ .getResult ();
135+ // Construct 16-bit values of e5 and e7
136+ // res_75 = | e7B1 | e7B0 | e5B1 | e5B0 | = | e7_f16 | e5_f16 |
137+ Value res_75 =
138+ LLVM::createLLVMCallOp (rewriter, loc, funcOp,
139+ {resB1_7531, resB0_7531, i32_val (0x07030602 )})
140+ .getResult ();
141+
142+ // Step 5: Reorder 16-bit elements to be 0,1,2,3,4,5,6,7
143+ // res_10 = | e1_f16 | e0_f16 |
144+ Value res_10 = LLVM::createLLVMCallOp (rewriter, loc, funcOp,
145+ {res_31, res_20, i32_val (0x05040100 )})
146+ .getResult ();
147+ // res_32 = | e3_f16 | e2_f16 |
148+ Value res_32 = LLVM::createLLVMCallOp (rewriter, loc, funcOp,
149+ {res_31, res_20, i32_val (0x07060302 )})
150+ .getResult ();
151+ // res_54 = | e5_f16 | e4_f16 |
152+ Value res_54 = LLVM::createLLVMCallOp (rewriter, loc, funcOp,
153+ {res_75, res_64, i32_val (0x05040100 )})
154+ .getResult ();
155+ // res_76 = | e7_f16 | e6_f16 |
156+ Value res_76 = LLVM::createLLVMCallOp (rewriter, loc, funcOp,
157+ {res_75, res_64, i32_val (0x07060302 )})
158+ .getResult ();
159+
160+ return {res_10, res_32, res_54, res_76};
128161}
129162
130163SmallVector<Value> upcastMxfp4 (RewriterBase &rewriter, UpcastMXFPOp upcastOp,
131164 bool toFp16, ArrayRef<Value> values) {
132- assert (values.size () % 2 == 0 );
165+ assert (values.size () % 4 == 0 );
133166 Location loc = upcastOp.getLoc ();
134167
135168 SmallVector<Value> results;
136169 results.reserve (values.size () * 2 );
137170 Type elemType = toFp16 ? f16_ty : bf16_ty;
138- for (int i = 0 ; i < values.size (); i += 2 ) {
139- SmallVector<Value, 2 > v4i32 =
140- upcast4xMxfp4 (rewriter, upcastOp, toFp16, values.slice (i, 2 ));
141- for (int j = 0 ; j < 2 ; j++) {
171+ for (int i = 0 ; i < values.size (); i += 4 ) {
172+ Value v0 = values[i];
173+ Value v1 = values[i + 1 ];
174+ Value v2 = values[i + 2 ];
175+ Value v3 = values[i + 3 ];
176+ Value packedVec = undef (vec_ty (i8_ty, 4 ));
177+ packedVec = insert_element (packedVec, v0, i32_val (0 ));
178+ packedVec = insert_element (packedVec, v1, i32_val (1 ));
179+ packedVec = insert_element (packedVec, v2, i32_val (2 ));
180+ packedVec = insert_element (packedVec, v3, i32_val (3 ));
181+ SmallVector<Value, 4 > v4i32 =
182+ upcast8xMxfp4 (rewriter, upcastOp, toFp16, packedVec);
183+ for (int j = 0 ; j < 4 ; j++) {
142184 Value elements = bitcast (v4i32[j], vec_ty (elemType, 2 ));
143185 results.push_back (extract_element (elements, i32_val (0 )));
144186 results.push_back (extract_element (elements, i32_val (1 )));
0 commit comments