|
19 | 19 | using namespace mlir;
|
20 | 20 | using namespace mlir::triton;
|
21 | 21 | using namespace mlir::triton::gpu;
|
| 22 | +using ::mlir::LLVM::AMD::upcast8xMxfp4_SW; |
22 | 23 |
|
23 | 24 | namespace {
|
24 | 25 |
|
25 |
| -SmallVector<Value, 4> upcast8xMxfp4_SW(RewriterBase &rewriter, |
26 |
| - amdgpu::UpcastMXFPOp upcastOp, |
27 |
| - bool tofp16, Value packedVec) { |
28 |
| - Location loc = upcastOp.getLoc(); |
29 |
| - auto b = TritonLLVMOpBuilder(loc, rewriter); |
30 |
| - |
31 |
| - // MXFP4 has 4 bits, S.EE.M, for Sign, Exponent, and Mantissa respectively. |
32 |
| - // For a specific S, we have a total of 8 bit patterns. We can encode all |
33 |
| - // these 8 resultant bf16/fp16 bit patterns in a lookup table (LUT). It |
34 |
| - // happens that llvm.amdgcn.perm supports selecting 4 bytes from 8 input bytes |
35 |
| - // using a 4-byte selector. So the overall idea is to use llvm.amdgcn.perm to |
36 |
| - // implement such a LUT; though we need to select the two bytes for the |
37 |
| - // resultant bf16/fp16 bit patterns separately. For the byte containing S, we |
38 |
| - // also need to handle the S and E bits separately. |
39 |
| - |
40 |
| - // FP4 has 4 bits: S.EE.M. Bf16/fp16 bit patterns for positive values: |
41 |
| - // |
42 |
| - // FP4 | BF16 | FP16 | Value |
43 |
| - // ------ | ------ | ------ | ----- |
44 |
| - // 0.00.0 | 0x0000 | 0x0000 | + 0.0 |
45 |
| - // 0.00.1 | 0x3f00 | 0x3800 | + 0.5 |
46 |
| - // 0.01.0 | 0x3f80 | 0x3c00 | + 1.0 |
47 |
| - // 0.01.1 | 0x3fc0 | 0x3e00 | + 1.5 |
48 |
| - // 0.10.0 | 0x4000 | 0x4000 | + 2.0 |
49 |
| - // 0.10.1 | 0x4040 | 0x4200 | + 3.0 |
50 |
| - // 0.11.0 | 0x4080 | 0x4400 | + 4.0 |
51 |
| - // 0.11.1 | 0x40c0 | 0x4600 | + 6.0 |
52 |
| - // |
53 |
| - // Encode Byte #0 (M) for BF16/FP16 in a LUT. |
54 |
| - Value resB0LutLo = tofp16 ? b.i32_val(0) : b.i32_val(0xc0800000); |
55 |
| - Value resB0LutHi = tofp16 ? b.i32_val(0) : b.i32_val(0xc0804000); |
56 |
| - // Encode Byte #1 (EM, non-S part) for BF16/FP16 in a LUT. |
57 |
| - Value resB1LutLoNoS = tofp16 ? b.i32_val(0x3e3c3800) : b.i32_val(0x3f3f3f00); |
58 |
| - Value resB1LutHiNoS = tofp16 ? b.i32_val(0x46444240) : b.i32_val(0x40404040); |
59 |
| - |
60 |
| - Type i32Ty = rewriter.getI32Type(); |
61 |
| - auto permU32FnTy = LLVM::LLVMFunctionType::get(i32Ty, {i32Ty, i32Ty, i32Ty}); |
62 |
| - LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp( |
63 |
| - rewriter, upcastOp, "llvm.amdgcn.perm", permU32FnTy); |
64 |
| - |
65 |
| - // Start with 8 mxfp4 elements in a single i32 register |
66 |
| - // | e7e6 | e5e4 | e3e2 | e1e0 | |
67 |
| - Value input = b.bitcast(packedVec, i32Ty); |
68 |
| - |
69 |
| - // Step 1: extract EM bits for elements 0,2,4,6 and 1,3,5,7 respectively. |
70 |
| - // e2m1_6420_idx = | 0[0e6EM] | 0[0e4EM] | 0[0e2EM] | 0[0e0EM] | |
71 |
| - Value e2m1_6420_idx = b.and_(input, b.i32_val(0x07070707)); |
72 |
| - // e2m1_7531_idx = | [0e7EM]0 | [0e5EM]0 | [0e3EM]0 | [0e1EM]0 | |
73 |
| - Value e2m1_7531_idx = b.and_(input, b.i32_val(0x70707070)); |
74 |
| - // e2m1_7531_idx = | 0[0e7EM] | 0[0e5EM] | 0[0e3EM] | 0[0e1EM] | |
75 |
| - e2m1_7531_idx = b.lshr(e2m1_7531_idx, b.i32_val(4)); |
76 |
| - |
77 |
| - // Step 2: extract S bit for elements 0,2,4,6 and 1,3,5,7 |
78 |
| - // s_6420 = | 0[e6S000] | 0[e4S000] | 0[e2S000] | 0[e0S000] | |
79 |
| - Value s_6420 = b.and_(input, b.i32_val(0x08080808)); |
80 |
| - // s_6420 = | [e6S000]0 | [e4S000]0 | [e2S000]0 | [e0S000]0 | |
81 |
| - s_6420 = b.shl(s_6420, b.i32_val(4)); |
82 |
| - // s_7531 = | [e7S000]0 | [e5S000]0 | [e3S000]0 | [e1S000]0 | |
83 |
| - Value s_7531 = b.and_(input, b.i32_val(0x80808080)); |
84 |
| - |
85 |
| - // Step 3: Upcast elements 0,2,4,6 to 4 16-bit elements |
86 |
| - // Select Byte #0. It's always 0 if upcasting to fp16. |
87 |
| - // resB0_6420 = | e6B0 | e4B0 | e2B0 | e0B0 | |
88 |
| - Value resB0_6420 = b.i32_val(0); |
89 |
| - if (!tofp16) { |
90 |
| - resB0_6420 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
91 |
| - {resB0LutHi, resB0LutLo, e2m1_6420_idx}) |
92 |
| - .getResult(); |
93 |
| - } |
94 |
| - // Select Byte #1 |
95 |
| - Value resB1NoS_6420 = |
96 |
| - LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
97 |
| - {resB1LutHiNoS, resB1LutLoNoS, e2m1_6420_idx}) |
98 |
| - .getResult(); |
99 |
| - // resB1_6420 = | e6B1 | e4B1 | e2B1 | e0B1 | |
100 |
| - Value resB1_6420 = b.or_(resB1NoS_6420, s_6420); |
101 |
| - // Construct 16-bit values of e0 and e2 |
102 |
| - // res_20 = | e2B1 | e2B0 | e0B1 | e0B0 | = | e2_f16 | e0_f16 | |
103 |
| - Value res_20 = |
104 |
| - LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
105 |
| - {resB1_6420, resB0_6420, b.i32_val(0x05010400)}) |
106 |
| - .getResult(); |
107 |
| - // Construct 16-bit values of e4 and e6 |
108 |
| - // res_64 = | e6B1 | e6B0 | e4B1 | e4B0 | = | e6_f16 | e4_f16 | |
109 |
| - Value res_64 = |
110 |
| - LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
111 |
| - {resB1_6420, resB0_6420, b.i32_val(0x07030602)}) |
112 |
| - .getResult(); |
113 |
| - |
114 |
| - // Step 4: Upcast elements 1,3,5,7 to 4 16-bit elements |
115 |
| - // This is a copy of step 3 on different group of elements |
116 |
| - // Select Byte #0. It's always 0 if upcasting to fp16. |
117 |
| - // resB0_7531 = | e7B0 | e5B0 | e3B0 | e1B0 | |
118 |
| - Value resB0_7531 = b.i32_val(0); |
119 |
| - if (!tofp16) { |
120 |
| - resB0_7531 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
121 |
| - {resB0LutHi, resB0LutLo, e2m1_7531_idx}) |
122 |
| - .getResult(); |
123 |
| - } |
124 |
| - // Select Byte #1 |
125 |
| - Value resB1NoS_7531 = |
126 |
| - LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
127 |
| - {resB1LutHiNoS, resB1LutLoNoS, e2m1_7531_idx}) |
128 |
| - .getResult(); |
129 |
| - // resB1_7531 = | e7B1 | e5B1 | e3B1 | e1B1 | |
130 |
| - Value resB1_7531 = b.or_(resB1NoS_7531, s_7531); |
131 |
| - // Construct 16-bit values of e1 and e3 |
132 |
| - // res_31 = | e3B1 | e3B0 | e1B1 | e1B0 | = | e3_f16 | e1_f16 | |
133 |
| - Value res_31 = |
134 |
| - LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
135 |
| - {resB1_7531, resB0_7531, b.i32_val(0x05010400)}) |
136 |
| - .getResult(); |
137 |
| - // Construct 16-bit values of e5 and e7 |
138 |
| - // res_75 = | e7B1 | e7B0 | e5B1 | e5B0 | = | e7_f16 | e5_f16 | |
139 |
| - Value res_75 = |
140 |
| - LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
141 |
| - {resB1_7531, resB0_7531, b.i32_val(0x07030602)}) |
142 |
| - .getResult(); |
143 |
| - |
144 |
| - // Step 5: Reorder 16-bit elements to be 0,1,2,3,4,5,6,7 |
145 |
| - // res_10 = | e1_f16 | e0_f16 | |
146 |
| - Value res_10 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
147 |
| - {res_31, res_20, b.i32_val(0x05040100)}) |
148 |
| - .getResult(); |
149 |
| - // res_32 = | e3_f16 | e2_f16 | |
150 |
| - Value res_32 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
151 |
| - {res_31, res_20, b.i32_val(0x07060302)}) |
152 |
| - .getResult(); |
153 |
| - // res_54 = | e5_f16 | e4_f16 | |
154 |
| - Value res_54 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
155 |
| - {res_75, res_64, b.i32_val(0x05040100)}) |
156 |
| - .getResult(); |
157 |
| - // res_76 = | e7_f16 | e6_f16 | |
158 |
| - Value res_76 = LLVM::createLLVMCallOp(rewriter, loc, funcOp, |
159 |
| - {res_75, res_64, b.i32_val(0x07060302)}) |
160 |
| - .getResult(); |
161 |
| - |
162 |
| - return {res_10, res_32, res_54, res_76}; |
163 |
| -} |
164 |
| - |
165 | 26 | SmallVector<Value, 8> upcastMxfp4_SW(RewriterBase &rewriter,
|
166 | 27 | amdgpu::UpcastMXFPOp upcastOp, bool toFp16,
|
167 | 28 | ArrayRef<Value> values, int idx) {
|
|
0 commit comments