@@ -13,80 +13,6 @@ using namespace mlir::triton::gpu;
13
13
using namespace mlir ::triton::gpu::intel;
14
14
15
15
namespace {
16
- SmallVector<Value> convertMxfp4x2ToBf16x2 (RewriterBase &rewriter, Location loc,
17
- ArrayRef<Value> values) {
18
- auto b = TritonLLVMOpBuilder (loc, rewriter);
19
- SmallVector<Value> results;
20
- for (auto v : values) {
21
- auto em0 = b.and_ (v, b.i8_val (0x7 ));
22
- auto em1 = b.and_ (v, b.i8_val (0x70 ));
23
- Value v0 =
24
- b.or_ (b.shl (b.zext (i16_ty, em0), b.i16_val (6 )),
25
- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x8 ))), b.i16_val (12 )));
26
- Value v1 =
27
- b.or_ (b.shl (b.zext (i16_ty, em1), b.i16_val (2 )),
28
- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x80 ))), b.i16_val (8 )));
29
- // Three cases:
30
- // 1) x is normal and non-zero: Correct bias
31
- v0 = b.select (b.icmp_ne (b.and_ (em0, b.i8_val (0x6 )), b.i8_val (0 )),
32
- b.add (v0, b.i16_val ((127 - 1 ) << 7 )), v0);
33
- v1 = b.select (b.icmp_ne (b.and_ (em1, b.i8_val (0x60 )), b.i8_val (0 )),
34
- b.add (v1, b.i16_val ((127 - 1 ) << 7 )), v1);
35
- // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
36
- // bf16
37
- v0 = b.bitcast (
38
- b.select (b.icmp_eq (em0, b.i8_val (0x1 )),
39
- b.or_ (b.i16_val (16128 ), b.and_ (v0, b.i16_val (0x8000 ))), v0),
40
- bf16_ty);
41
- v1 = b.bitcast (
42
- b.select (b.icmp_eq (em1, b.i8_val (0x10 )),
43
- b.or_ (b.i16_val (16128 ), b.and_ (v1, b.i16_val (0x8000 ))), v1),
44
- bf16_ty);
45
- // 3) x is zero, nothing to do
46
- results.push_back (v0);
47
- results.push_back (v1);
48
- }
49
- return results;
50
- }
51
-
52
- SmallVector<Value> convertMxfp4x2ToFp16x2 (RewriterBase &rewriter, Location loc,
53
- ArrayRef<Value> values) {
54
- auto b = TritonLLVMOpBuilder (loc, rewriter);
55
- SmallVector<Value> results;
56
- for (auto v : values) {
57
- auto em0 = b.and_ (v, b.i8_val (0x7 ));
58
- auto em1 = b.and_ (v, b.i8_val (0x70 ));
59
- // FP16 bits: sign = 1, exponent = 5, mantissa = 10
60
- Value v0 =
61
- b.or_ (b.shl (b.zext (i16_ty, em0), b.i16_val (10 - 1 )),
62
- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x8 ))), b.i16_val (12 )));
63
- Value v1 =
64
- b.or_ (b.shl (b.zext (i16_ty, em1), b.i16_val (10 - 1 - 4 )),
65
- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x80 ))), b.i16_val (8 )));
66
-
67
- // Three cases:
68
- // 1) x is normal and non-zero: Correct bias
69
- v0 = b.select (b.icmp_ne (b.and_ (em0, b.i8_val (0x6 )), b.i8_val (0 )),
70
- b.add (v0, b.i16_val ((15 - 1 ) << 10 )), v0);
71
- v1 = b.select (b.icmp_ne (b.and_ (em1, b.i8_val (0x60 )), b.i8_val (0 )),
72
- b.add (v1, b.i16_val ((15 - 1 ) << 10 )), v1);
73
-
74
- // 2) x is subnormal (x == 0bs001 where s is the sign): Map to fp16 +-0.5
75
- v0 = b.bitcast (
76
- b.select (b.icmp_eq (em0, b.i8_val (0x1 )),
77
- b.or_ (b.i16_val (0x3800 ), b.and_ (v0, b.i16_val (0x8000 ))), v0),
78
- f16_ty);
79
- v1 = b.bitcast (
80
- b.select (b.icmp_eq (em1, b.i8_val (0x10 )),
81
- b.or_ (b.i16_val (0x3800 ), b.and_ (v1, b.i16_val (0x8000 ))), v1),
82
- f16_ty);
83
- // 3) x is zero, nothing to do
84
- results.push_back (v0);
85
- results.push_back (v1);
86
- }
87
- return results;
88
- }
89
-
90
16
class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern <Fp4ToFpOp> {
91
17
public:
92
18
Fp4ToFpOpPattern (LLVMTypeConverter &typeConverter, PatternBenefit benefit)
@@ -96,21 +22,51 @@ class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern<Fp4ToFpOp> {
96
22
matchAndRewrite (Fp4ToFpOp op, OpAdaptor adaptor,
97
23
ConversionPatternRewriter &rewriter) const override {
98
24
Location loc = op.getLoc ();
99
- auto *ctx = op.getContext ();
100
25
Type elemType = op.getType ().getElementType ();
101
26
assert (elemType == f16_ty || elemType == bf16_ty);
102
- bool toFp16 = elemType == f16_ty;
103
-
104
- SmallVector<Value> xVals =
105
- unpackLLElements (loc, adaptor.getSrc (), rewriter);
106
- xVals = toFp16 ? convertMxfp4x2ToFp16x2 (rewriter, loc, xVals)
107
- : convertMxfp4x2ToBf16x2 (rewriter, loc, xVals);
108
27
109
- Value result =
110
- packLLElements (loc, getTypeConverter (), xVals, rewriter, op.getType ());
111
- rewriter.replaceOp (op, result);
28
+ SmallVector<Value> results;
29
+ {
30
+ SmallVector<Value> xVals =
31
+ unpackLLElements (loc, adaptor.getSrc (), rewriter);
32
+ convertMxfp4x2ToFloat (rewriter, loc, xVals, results,
33
+ elemType == f16_ty ? f16_ty : bf16_ty);
34
+ }
35
+ rewriter.replaceOp (op, packLLElements (loc, getTypeConverter (), results,
36
+ rewriter, op.getType ()));
112
37
return success ();
113
38
}
39
+
40
+ private:
41
+ static void convertMxfp4x2ToFloat (RewriterBase &rewriter, Location loc,
42
+ SmallVector<Value> &values,
43
+ SmallVector<Value> &results,
44
+ FloatType floatTy) {
45
+ assert (results.empty () && !values.empty ());
46
+
47
+ Value table;
48
+ { // Create a constant vector containing all the possible values
49
+ auto vecTy = VectorType::get ({16 }, floatTy);
50
+ SmallVector<Attribute, 16 > values;
51
+ for (double v : {0 ., 0.5 , 1 ., 1.5 , 2 ., 3 ., 4 ., 6 ., -0 ., -0.5 , -1 ., -1.5 ,
52
+ -2 ., -3 ., -4 ., -6 .})
53
+ values.push_back (rewriter.getFloatAttr (floatTy, v));
54
+ table = rewriter.create <LLVM::ConstantOp>(
55
+ loc, vecTy, DenseElementsAttr::get (vecTy, values));
56
+ }
57
+
58
+ TritonLLVMOpBuilder b (loc, rewriter);
59
+ Value i8_4 = b.i8_val (4 );
60
+ Value i8_15 = b.i8_val (15 );
61
+ results.reserve (values.size () * 2 );
62
+ for (Value v : values) {
63
+ // The first and last 4 bits are the values indices in the table
64
+ Value idx1 = b.and_ (v, i8_15);
65
+ Value idx2 = b.lshr (v, i8_4);
66
+ results.push_back (b.extract_element (table, idx1));
67
+ results.push_back (b.extract_element (table, idx2));
68
+ }
69
+ }
114
70
};
115
71
} // anonymous namespace
116
72
0 commit comments