@@ -17,10 +17,8 @@ using namespace mlir::triton::gpu;
1717namespace {
1818
1919static Value mxfpScaleBf16 (ConversionPatternRewriter &rewriter, Location loc,
20- Value v, Value scale) {
20+ Value v, Value scale, bool fastMath ) {
2121 Value vBf16 = bitcast (v, bf16_ty);
22- Value nanBf16 = bitcast (i16_val (0x7fff ), bf16_ty);
23- Value scaleIsNan = icmp_eq (scale, i8_val (0xff ));
2422 Value scaleBf16 = bitcast (shl (zext (i16_ty, scale), i16_val (7 )), bf16_ty);
2523
2624 Value v0 = mlir::triton::intel::convertBf16ToFp32 (loc, rewriter, vBf16);
@@ -29,7 +27,11 @@ static Value mxfpScaleBf16(ConversionPatternRewriter &rewriter, Location loc,
2927 auto undefRounding = static_cast <mlir::triton::RoundingMode>(-1 );
3028 Value scaledBf16 = mlir::triton::intel::convertFp32ToBf16 (
3129 loc, rewriter, result, undefRounding);
30+ if (fastMath)
31+ return scaledBf16;
3232 // Account for NaN in the scale as per the mxfp specification.
33+ Value scaleIsNan = icmp_eq (scale, i8_val (0xff ));
34+ Value nanBf16 = bitcast (i16_val (0x7fff ), bf16_ty);
3335 return select (scaleIsNan, nanBf16, scaledBf16);
3436};
3537
@@ -104,8 +106,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
104106 for (int k = 0 ; k < kWidth ; ++k) {
105107 unsigned idx = i * scalingBlockSize + mxfp * mxfpSize +
106108 rep * subTileSize * kWidth + subTile * kWidth + k;
107- xVals[idx] =
108- mxfpScaleBf16 (rewriter, loc, xVals[idx], si[subTile] );
109+ xVals[idx] = mxfpScaleBf16 (rewriter, loc, xVals[idx], si[subTile],
110+ op. getFastMath () );
109111 }
110112 }
111113 }
0 commit comments