Commit 73b9356
[Intel] Port changes to support fp16 scaled dot (#3153)

These changes are ported from f9d9fad, but they are not enough to make the fp16 scaled dot unit tests pass; more investigation is needed to enable those test cases. Part of #3141.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent f7aaf04 commit 73b9356

2 files changed (+84, -19 lines)

third_party/intel/lib/TritonIntelGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 57 additions & 4 deletions
@@ -16,6 +16,53 @@ using namespace mlir::triton::gpu::intel;
 
 namespace {
 
+SmallVector<Value> convertMxfp4x2ToFp16x2(RewriterBase &rewriter, Location loc,
+                                          ArrayRef<Value> values) {
+  SmallVector<Value> results;
+  for (auto v : values) {
+    auto em0 = and_(v, i8_val(0x7));
+    auto em1 = and_(v, i8_val(0x70));
+    // FP16 bits: sign = 1, exponent = 5, mantissa = 10
+    Value v0 = or_(shl(zext(i16_ty, em0), i16_val(10 - 1)),
+                   shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12)));
+    Value v1 = or_(shl(zext(i16_ty, em1), i16_val(10 - 1 - 4)),
+                   shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8)));
+
+    // Three cases:
+    // 1) x is normal and non-zero: Correct bias
+    v0 = select(icmp_ne(and_(em0, i8_val(0x6)), i8_val(0)),
+                add(v0, i16_val((15 - 1) << 10)), v0);
+    v1 = select(icmp_ne(and_(em1, i8_val(0x60)), i8_val(0)),
+                add(v1, i16_val((15 - 1) << 10)), v1);
+
+    // 2) x is subnormal (x == 0bs001 where s is the sign): Map to fp16 +-0.5
+    v0 = bitcast(select(icmp_eq(em0, i8_val(0x1)),
+                        or_(i16_val(0x3800), and_(v0, i16_val(0x8000))), v0),
+                 f16_ty);
+    v1 = bitcast(select(icmp_eq(em1, i8_val(0x10)),
+                        or_(i16_val(0x3800), and_(v1, i16_val(0x8000))), v1),
+                 f16_ty);
+    // 3) x is zero, nothing to do
+    results.push_back(v0);
+    results.push_back(v1);
+  }
+  return results;
+}
+
+Value mxfpScaleFp16(ConversionPatternRewriter &rewriter, Location loc, Value v,
+                    Value scale, bool fastMath) {
+  Value scaleF32 = bitcast(shl(zext(i32_ty, scale), i32_val(23)), f32_ty);
+  Value scaleF16 = LLVM::intel::convertFp32ToFp16(loc, rewriter, scaleF32,
+                                                  RoundingMode::RTNE);
+  Value mulF16 = fmul(v, scaleF16);
+  if (fastMath)
+    return mulF16;
+  // Account for NaN in the scale as per the mxfp specification.
+  Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
+  Value nanF16 = bitcast(i16_val(0x7c01), f16_ty);
+  return select(scaleIsNan, nanF16, bitcast(mulF16, f16_ty));
+};
+
 static Value mxfpScaleBf16(ConversionPatternRewriter &rewriter, Location loc,
                            Value v, Value scale, bool fastMath) {
   Value vBf16 = bitcast(v, bf16_ty);
@@ -61,8 +108,11 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
     Value warpId = udiv(tid, warpSize);
     Value laneId = urem(tid, warpSize);
 
-    if (fpType == ScaleDotElemType::E2M1)
-      xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
+    bool useFp16 = op.getType().getElementType().isF16();
+    if (fpType == ScaleDotElemType::E2M1) {
+      xVals = useFp16 ? convertMxfp4x2ToFp16x2(rewriter, loc, xVals)
+                      : LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
+    }
 
     auto xType = cast<RankedTensorType>(op->getOperandTypes()[0]);
     auto dotEnc = cast<DotOperandEncodingAttr>(xType.getEncoding());
@@ -106,8 +156,11 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
         for (int k = 0; k < kWidth; ++k) {
           unsigned idx = i * scalingBlockSize + mxfp * mxfpSize +
                          rep * subTileSize * kWidth + subTile * kWidth + k;
-          xVals[idx] = mxfpScaleBf16(rewriter, loc, xVals[idx], si[subTile],
-                                     op.getFastMath());
+          xVals[idx] = useFp16
+                           ? mxfpScaleFp16(rewriter, loc, xVals[idx],
+                                           si[subTile], op.getFastMath())
+                           : mxfpScaleBf16(rewriter, loc, xVals[idx],
+                                           si[subTile], op.getFastMath());
         }
       }
     }
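
For readers tracing the bit manipulation in convertMxfp4x2ToFp16x2: each byte packs two e2m1 values (sign 1 bit, exponent 2 bits with bias 1, mantissa 1 bit). The scalar sketch below is a plain-C++ reading of the same logic, not part of the commit; the function name is hypothetical, and fp16 results are shown as raw uint16_t bit patterns.

#include <cstdint>
#include <utility>

// Decode one packed mxfp4 (e2m1) byte into two fp16 bit patterns,
// mirroring convertMxfp4x2ToFp16x2 above. Hypothetical helper.
std::pair<uint16_t, uint16_t> mxfp4x2ToFp16x2Bits(uint8_t v) {
  uint8_t em0 = v & 0x7;   // exponent+mantissa of the low nibble
  uint8_t em1 = v & 0x70;  // exponent+mantissa of the high nibble
  // Move the e2m1 fields into the fp16 exponent/mantissa slots and the
  // sign bits (0x8 / 0x80) into bit 15.
  uint16_t v0 = (uint16_t(em0) << (10 - 1)) | (uint16_t(v & 0x8) << 12);
  uint16_t v1 = (uint16_t(em1) << (10 - 1 - 4)) | (uint16_t(v & 0x80) << 8);
  // 1) Normal, non-zero value (exponent bits set): rebias from the e2m1
  //    bias (1) to the fp16 bias (15).
  if (em0 & 0x6)
    v0 += (15 - 1) << 10;
  if (em1 & 0x60)
    v1 += (15 - 1) << 10;
  // 2) Subnormal value (0bs001 == +-0.5): fp16 0x3800 with the sign kept.
  if (em0 == 0x1)
    v0 = 0x3800 | (v0 & 0x8000);
  if (em1 == 0x10)
    v1 = 0x3800 | (v1 & 0x8000);
  // 3) Zero: all bits are already zero.
  return {v0, v1};
}

For example, the low nibble 0b0111 (e = 3, m = 1) decodes to 1.5 * 2^(3-1) = 6.0, i.e. fp16 bit pattern 0x4600.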
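mxfpScaleFp16 relies on the e8m0 scale byte being a bare biased exponent: shifting it into the fp32 exponent field (bits 30:23) reconstructs the power of two directly, after which a single fmul applies it. A scalar sketch under the same assumptions (plain C++, hypothetical names, fp32 standing in for fp16):

#include <cstdint>
#include <cstring>
#include <limits>

// Decode an e8m0 scale: the byte is a biased fp32 exponent, so placing it
// in the exponent field yields 2^(scale - 127). Hypothetical helper.
float decodeE8M0(uint8_t scale) {
  uint32_t bits = uint32_t(scale) << 23;
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

// Apply the scale; outside fast-math, scale == 0xff encodes NaN per the
// mxfp specification and must propagate to the result.
float mxfpScale(float v, uint8_t scale, bool fastMath) {
  float scaled = v * decodeE8M0(scale);
  if (fastMath)
    return scaled;
  return scale == 0xff ? std::numeric_limits<float>::quiet_NaN() : scaled;
}

The MLIR version additionally rounds the decoded scale from fp32 to fp16 (RTNE) before the multiply, since the operand vector is already fp16.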

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 27 additions & 15 deletions
@@ -227,7 +227,8 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
       return elemType == tt::ScaleDotElemType::E2M1 ||
              elemType == tt::ScaleDotElemType::E4M3 ||
              elemType == tt::ScaleDotElemType::E5M2 ||
-             elemType == tt::ScaleDotElemType::BF16;
+             elemType == tt::ScaleDotElemType::BF16 ||
+             elemType == tt::ScaleDotElemType::FP16;
     };
     if (!supportsTypes(aElemType) || !supportsTypes(bElemType))
       return rewriter.notifyMatchFailure(scaledDotOp, "NYI: mxfp6 operand");
@@ -263,27 +264,31 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     assert((aDesc.scale || bDesc.scale) && "No scale provided");
     assert(!(aDesc.scale && bDesc.scale) && "NYI: Both LHS and RHS scale");
 
+    bool useFp16 = aDesc.elemType == tt::ScaleDotElemType::FP16 ||
+                   bDesc.elemType == tt::ScaleDotElemType::FP16;
+
     if (aDesc.scale) {
       TensorValue newA =
           convertScaledOperand<ttgi::DpasEncodingAttr::OpIdx::OperandA>(
-              aDesc, fastMath, dpasEnc, newRetType, mod, rewriter);
+              aDesc, useFp16, fastMath, dpasEnc, newRetType, mod, rewriter);
       TensorValue newB =
           convertUnscaledOperand<ttgi::DpasEncodingAttr::OpIdx::OperandB>(
-              bDesc, dpasEnc, newRetType, rewriter);
+              bDesc, useFp16, dpasEnc, newRetType, rewriter);
       return {newA, newB};
     }
 
     TensorValue newB =
         convertScaledOperand<ttgi::DpasEncodingAttr::OpIdx::OperandB>(
-            bDesc, fastMath, dpasEnc, newRetType, mod, rewriter);
+            bDesc, useFp16, fastMath, dpasEnc, newRetType, mod, rewriter);
     TensorValue newA =
         convertUnscaledOperand<ttgi::DpasEncodingAttr::OpIdx::OperandA>(
-            aDesc, dpasEnc, newRetType, rewriter);
+            aDesc, useFp16, dpasEnc, newRetType, rewriter);
     return {newA, newB};
   }
 
   template <ttgi::DpasEncodingAttr::OpIdx opIdx>
-  TensorValue convertScaledOperand(OpDescriptor opDesc, bool fastMath,
+  TensorValue convertScaledOperand(OpDescriptor opDesc, bool useFp16,
+                                   bool fastMath,
                                    ttg::intel::DpasEncodingAttr dpasEnc,
                                    RankedTensorType retType, ModuleOp mod,
                                    PatternRewriter &rewriter) const {
@@ -304,7 +309,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     auto newOpEncoding = ttg::DotOperandEncodingAttr::get(
         ctx, unsigned(opIdx), opEncoding, opEncoding.getOpsPerChannel());
     TensorValue op =
-        createArg(opDesc.op, opDesc.elemType, newOpEncoding, rewriter);
+        createArg(opDesc.op, opDesc.elemType, useFp16, newOpEncoding, rewriter);
 
     unsigned instrShapeM = dpasEnc.getDPASInstShapeA()[0];
     SmallVector<unsigned, 2> threadsPerWarp{instrShapeM,
@@ -332,7 +337,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
   }
 
   template <ttgi::DpasEncodingAttr::OpIdx opIdx>
-  TensorValue convertUnscaledOperand(OpDescriptor opDesc,
+  TensorValue convertUnscaledOperand(OpDescriptor opDesc, bool useFp16,
                                      ttg::intel::DpasEncodingAttr dpasEnc,
                                      RankedTensorType retType,
                                      PatternRewriter &rewriter) const {
@@ -341,7 +346,8 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     auto newOpEncoding = ttg::DotOperandEncodingAttr::get(
         opDesc.op.getContext(), unsigned(opIdx), dpasEnc,
         dpasEnc.getOpsPerChannel());
-    return createArg(opDesc.op, opDesc.elemType, newOpEncoding, rewriter);
+    return createArg(opDesc.op, opDesc.elemType, useFp16, newOpEncoding,
+                     rewriter);
   }
 
   ttg::intel::DpasEncodingAttr
@@ -385,7 +391,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
                                          oldAcc);
   }
 
-  TensorValue createArg(TensorValue v, tt::ScaleDotElemType type,
+  TensorValue createArg(TensorValue v, tt::ScaleDotElemType type, bool useFp16,
                         Attribute vEncoding, PatternRewriter &rewriter) const {
     RankedTensorType vType = v.getType();
     auto newVType = RankedTensorType::get(vType.getShape(),
@@ -395,13 +401,16 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
 
     // convert to bf16
     if (type != tt::ScaleDotElemType::E2M1 &&
-        type != tt::ScaleDotElemType::BF16) {
+        type != tt::ScaleDotElemType::BF16 &&
+        type != tt::ScaleDotElemType::FP16) {
       assert(type == tt::ScaleDotElemType::E5M2 ||
              type == tt::ScaleDotElemType::E4M3);
-      auto vTypeBf16 = RankedTensorType::get(
-          newVType.getShape(), rewriter.getBF16Type(), newVType.getEncoding());
+      auto upcastedType = RankedTensorType::get(
+          newVType.getShape(),
+          useFp16 ? rewriter.getF16Type() : rewriter.getBF16Type(),
+          newVType.getEncoding());
       ret = cast<TypedValue<RankedTensorType>>(
-          rewriter.create<tt::FpToFpOp>(v.getLoc(), vTypeBf16, ret)
+          rewriter.create<tt::FpToFpOp>(v.getLoc(), upcastedType, ret)
               .getResult());
     }
     return ret;
@@ -423,8 +432,11 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     if (!scale)
       return v;
 
+    Builder b(v.getContext());
+    bool useFp16 = elemType == tt::ScaleDotElemType::FP16;
+    Type outputElemType = useFp16 ? b.getF16Type() : b.getBF16Type();
     auto retTy = triton::gpu::intel::UpcastMXFPOp::deduceOutputType(
-        v, elemType, Builder(v.getContext()).getBF16Type());
+        v, elemType, outputElemType);
     return rewriter.create<ttgi::UpcastMXFPOp>(v.getLoc(), retTy, v, scale,
                                                elemType, fastMath);
   }
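
One reason fp16 is a natural intermediate in the E5M2 branch of createArg: widening e5m2 to fp16 is lossless and amounts to a pure bit shift, since e5m2 (1-5-2) and fp16 (1-5-10) share the same exponent width and bias. A sketch of that observation, not part of the commit (the FpToFpOp above is the actual mechanism):

#include <cstdint>

// Widen an e5m2 byte to an fp16 bit pattern: sign, exponent, and the two
// mantissa bits keep their relative positions; eight zero mantissa bits are
// appended. E4M3 needs a real conversion since its exponent bias (7) differs.
uint16_t e5m2ToFp16Bits(uint8_t x) { return uint16_t(x) << 8; }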
