Skip to content

Commit 2f338af

Browse files
Improved performance of the fp4tofp conversion (#4299)
Use a simple lookup table instead of explicit conversion. Fixes #4298 This implementation creates 3 constants: ```llvm %32 = llvm.mlir.constant(dense<[0.000000e+00, 5.000000e-01, 1.000000e+00, 1.500000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 6.000000e+00, -0.000000e+00, -5.000000e-01, -1.000000e+00, -1.500000e+00, -2.000000e+00, -3.000000e+00, -4.000000e+00, -6.000000e+00]> : vector<16xbf16>) : vector<16xbf16> %33 = llvm.mlir.constant(4 : i8) : i8 %34 = llvm.mlir.constant(15 : i8) : i8 ``` and 4 operations per each pair of values: ```llvm %35 = llvm.and %0, %34 : i8 %36 = llvm.lshr %0, %33 : i8 %37 = llvm.extractelement %32[%35 : i8] : vector<16xbf16> %38 = llvm.extractelement %32[%36 : i8] : vector<16xbf16> ``` I've not compared the performance, but it seems more efficient than #4298 . Co-authored-by: Ettore Tiotto <[email protected]>
1 parent cc14d65 commit 2f338af

File tree

1 file changed

+40
-84
lines changed

1 file changed

+40
-84
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/Fp4ToFpOpToLLVM.cpp

Lines changed: 40 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -13,80 +13,6 @@ using namespace mlir::triton::gpu;
1313
using namespace mlir::triton::gpu::intel;
1414

1515
namespace {
16-
SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
17-
ArrayRef<Value> values) {
18-
auto b = TritonLLVMOpBuilder(loc, rewriter);
19-
SmallVector<Value> results;
20-
for (auto v : values) {
21-
auto em0 = b.and_(v, b.i8_val(0x7));
22-
auto em1 = b.and_(v, b.i8_val(0x70));
23-
Value v0 =
24-
b.or_(b.shl(b.zext(i16_ty, em0), b.i16_val(6)),
25-
b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x8))), b.i16_val(12)));
26-
Value v1 =
27-
b.or_(b.shl(b.zext(i16_ty, em1), b.i16_val(2)),
28-
b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x80))), b.i16_val(8)));
29-
// Three cases:
30-
// 1) x is normal and non-zero: Correct bias
31-
v0 = b.select(b.icmp_ne(b.and_(em0, b.i8_val(0x6)), b.i8_val(0)),
32-
b.add(v0, b.i16_val((127 - 1) << 7)), v0);
33-
v1 = b.select(b.icmp_ne(b.and_(em1, b.i8_val(0x60)), b.i8_val(0)),
34-
b.add(v1, b.i16_val((127 - 1) << 7)), v1);
35-
// 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
36-
// bf16
37-
v0 = b.bitcast(
38-
b.select(b.icmp_eq(em0, b.i8_val(0x1)),
39-
b.or_(b.i16_val(16128), b.and_(v0, b.i16_val(0x8000))), v0),
40-
bf16_ty);
41-
v1 = b.bitcast(
42-
b.select(b.icmp_eq(em1, b.i8_val(0x10)),
43-
b.or_(b.i16_val(16128), b.and_(v1, b.i16_val(0x8000))), v1),
44-
bf16_ty);
45-
// 3) x is zero, nothing to do
46-
results.push_back(v0);
47-
results.push_back(v1);
48-
}
49-
return results;
50-
}
51-
52-
SmallVector<Value> convertMxfp4x2ToFp16x2(RewriterBase &rewriter, Location loc,
53-
ArrayRef<Value> values) {
54-
auto b = TritonLLVMOpBuilder(loc, rewriter);
55-
SmallVector<Value> results;
56-
for (auto v : values) {
57-
auto em0 = b.and_(v, b.i8_val(0x7));
58-
auto em1 = b.and_(v, b.i8_val(0x70));
59-
// FP16 bits: sign = 1, exponent = 5, mantissa = 10
60-
Value v0 =
61-
b.or_(b.shl(b.zext(i16_ty, em0), b.i16_val(10 - 1)),
62-
b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x8))), b.i16_val(12)));
63-
Value v1 =
64-
b.or_(b.shl(b.zext(i16_ty, em1), b.i16_val(10 - 1 - 4)),
65-
b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x80))), b.i16_val(8)));
66-
67-
// Three cases:
68-
// 1) x is normal and non-zero: Correct bias
69-
v0 = b.select(b.icmp_ne(b.and_(em0, b.i8_val(0x6)), b.i8_val(0)),
70-
b.add(v0, b.i16_val((15 - 1) << 10)), v0);
71-
v1 = b.select(b.icmp_ne(b.and_(em1, b.i8_val(0x60)), b.i8_val(0)),
72-
b.add(v1, b.i16_val((15 - 1) << 10)), v1);
73-
74-
// 2) x is subnormal (x == 0bs001 where s is the sign): Map to fp16 +-0.5
75-
v0 = b.bitcast(
76-
b.select(b.icmp_eq(em0, b.i8_val(0x1)),
77-
b.or_(b.i16_val(0x3800), b.and_(v0, b.i16_val(0x8000))), v0),
78-
f16_ty);
79-
v1 = b.bitcast(
80-
b.select(b.icmp_eq(em1, b.i8_val(0x10)),
81-
b.or_(b.i16_val(0x3800), b.and_(v1, b.i16_val(0x8000))), v1),
82-
f16_ty);
83-
// 3) x is zero, nothing to do
84-
results.push_back(v0);
85-
results.push_back(v1);
86-
}
87-
return results;
88-
}
89-
9016
class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern<Fp4ToFpOp> {
9117
public:
9218
Fp4ToFpOpPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit)
@@ -96,21 +22,51 @@ class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern<Fp4ToFpOp> {
9622
matchAndRewrite(Fp4ToFpOp op, OpAdaptor adaptor,
9723
ConversionPatternRewriter &rewriter) const override {
9824
Location loc = op.getLoc();
99-
auto *ctx = op.getContext();
10025
Type elemType = op.getType().getElementType();
10126
assert(elemType == f16_ty || elemType == bf16_ty);
102-
bool toFp16 = elemType == f16_ty;
103-
104-
SmallVector<Value> xVals =
105-
unpackLLElements(loc, adaptor.getSrc(), rewriter);
106-
xVals = toFp16 ? convertMxfp4x2ToFp16x2(rewriter, loc, xVals)
107-
: convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
10827

109-
Value result =
110-
packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType());
111-
rewriter.replaceOp(op, result);
28+
SmallVector<Value> results;
29+
{
30+
SmallVector<Value> xVals =
31+
unpackLLElements(loc, adaptor.getSrc(), rewriter);
32+
convertMxfp4x2ToFloat(rewriter, loc, xVals, results,
33+
elemType == f16_ty ? f16_ty : bf16_ty);
34+
}
35+
rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), results,
36+
rewriter, op.getType()));
11237
return success();
11338
}
39+
40+
private:
41+
static void convertMxfp4x2ToFloat(RewriterBase &rewriter, Location loc,
42+
SmallVector<Value> &values,
43+
SmallVector<Value> &results,
44+
FloatType floatTy) {
45+
assert(results.empty() && !values.empty());
46+
47+
Value table;
48+
{ // Create a constant vector containing all the possible values
49+
auto vecTy = VectorType::get({16}, floatTy);
50+
SmallVector<Attribute, 16> values;
51+
for (double v : {0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -0.5, -1., -1.5,
52+
-2., -3., -4., -6.})
53+
values.push_back(rewriter.getFloatAttr(floatTy, v));
54+
table = rewriter.create<LLVM::ConstantOp>(
55+
loc, vecTy, DenseElementsAttr::get(vecTy, values));
56+
}
57+
58+
TritonLLVMOpBuilder b(loc, rewriter);
59+
Value i8_4 = b.i8_val(4);
60+
Value i8_15 = b.i8_val(15);
61+
results.reserve(values.size() * 2);
62+
for (Value v : values) {
63+
// The first and last 4 bits are the values indices in the table
64+
Value idx1 = b.and_(v, i8_15);
65+
Value idx2 = b.lshr(v, i8_4);
66+
results.push_back(b.extract_element(table, idx1));
67+
results.push_back(b.extract_element(table, idx2));
68+
}
69+
}
11470
};
11571
} // anonymous namespace
11672

0 commit comments

Comments
 (0)