@@ -53,12 +53,14 @@ SmallVector<Value> convertMxfp4x2ToFp16x2(RewriterBase &rewriter, Location loc,
5353 return results;
5454}
5555
56- Value mxfpScaleFp16 (RewriterBase &rewriter, Location loc, Value v,
57- Value scale ) {
56+ Value mxfpScaleFp16 (RewriterBase &rewriter, Location loc, Value v, Value scale,
57+ bool fastMath ) {
5858 Value scaleF32 = bitcast (shl (zext (i32_ty, scale), i32_val (23 )), f32_ty);
5959 Value scaleF16 =
6060 LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, scaleF32, RoundingMode::RTNE);
6161 Value mulF16 = fmul (v, scaleF16);
62+ if (fastMath)
63+ return mulF16;
6264 // Account for NaN in the scale as per the mxfp specification.
6365 Value scaleIsNan = icmp_eq (scale, i8_val (0xff ));
6466 Value nanF16 = bitcast (i16_val (0x7c01 ), f16_ty);
@@ -72,16 +74,19 @@ Value mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v,
7274// handles v * scale multiplication using fp32 VALU ops. LLVM backend can do it
7375// for us, just with unnecessary overheads.
7476Value mxfpScaleBf16ViaF32 (RewriterBase &rewriter, Location loc, Value v,
75- Value scale) {
77+ Value scale, bool fastMath ) {
7678 Value c16 = i32_val (16 );
7779 Value vF32 = bitcast (shl (zext (i32_ty, bitcast (v, i16_ty)), c16), f32_ty);
7880 Value scaleF32 = bitcast (shl (zext (i32_ty, scale), i32_val (23 )), f32_ty);
7981 Value mulF32 = fmul (vF32, scaleF32);
8082 Value mulI16 = trunc (i16_ty, lshr (bitcast (mulF32, i32_ty), c16));
83+ Value mulBf16 = bitcast (mulI16, bf16_ty);
84+ if (fastMath)
85+ return mulBf16;
8186 // Account for NaN in the scale as per the mxfp specification.
8287 Value scaleIsNan = icmp_eq (scale, i8_val (0xff ));
8388 Value nanBf16 = bitcast (i16_val (0x7fff ), bf16_ty);
84- return select (scaleIsNan, nanBf16, bitcast (mulI16, bf16_ty) );
89+ return select (scaleIsNan, nanBf16, mulBf16 );
8590};
8691
8792class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern <UpcastMXFPOp> {
@@ -166,9 +171,10 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
166171 for (int j = 0 ; j < 32 ; ++j) {
167172 int index = 32 * i + j;
168173 xVals[index] =
169- useFp16 ? mxfpScaleFp16 (rewriter, loc, xVals[index], si[j / 16 ])
174+ useFp16 ? mxfpScaleFp16 (rewriter, loc, xVals[index], si[j / 16 ],
175+ op.getFastMath ())
170176 : mxfpScaleBf16ViaF32 (rewriter, loc, xVals[index],
171- si[j / 16 ]);
177+ si[j / 16 ], op. getFastMath () );
172178 }
173179 }
174180 } else {
@@ -190,10 +196,11 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
190196
191197 for (int j = 0 ; j < 32 ; ++j) {
192198 int index = 32 * i + j;
193- xVals[index] =
194- useFp16
195- ? mxfpScaleFp16 (rewriter, loc, xVals[index], si[j / 16 ])
196- : mxfpScaleBf16ViaF32 (rewriter, loc, xVals[index], si[j / 8 ]);
199+ xVals[index] = useFp16
200+ ? mxfpScaleFp16 (rewriter, loc, xVals[index],
201+ si[j / 16 ], op.getFastMath ())
202+ : mxfpScaleBf16ViaF32 (rewriter, loc, xVals[index],
203+ si[j / 8 ], op.getFastMath ());
197204 }
198205 }
199206 }
0 commit comments