Skip to content

Commit 9649f71

Browse files
authored
[AMD] Bypass NaN check for fast math scaled dot (triton-lang#5584)
Following triton-lang#5582.
1 parent f7e6775 commit 9649f71

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,14 @@ SmallVector<Value> convertMxfp4x2ToFp16x2(RewriterBase &rewriter, Location loc,
5353
return results;
5454
}
5555

56-
Value mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v,
57-
Value scale) {
56+
Value mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v, Value scale,
57+
bool fastMath) {
5858
Value scaleF32 = bitcast(shl(zext(i32_ty, scale), i32_val(23)), f32_ty);
5959
Value scaleF16 =
6060
LLVM::AMD::cvtFp32ToFp16(loc, rewriter, scaleF32, RoundingMode::RTNE);
6161
Value mulF16 = fmul(v, scaleF16);
62+
if (fastMath)
63+
return mulF16;
6264
// Account for NaN in the scale as per the mxfp specification.
6365
Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
6466
Value nanF16 = bitcast(i16_val(0x7c01), f16_ty);
@@ -72,16 +74,19 @@ Value mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v,
7274
// handles v * scale multiplication using fp32 VALU ops. LLVM backend can do it
7375
// for us, just with unnecessary overheads.
7476
Value mxfpScaleBf16ViaF32(RewriterBase &rewriter, Location loc, Value v,
75-
Value scale) {
77+
Value scale, bool fastMath) {
7678
Value c16 = i32_val(16);
7779
Value vF32 = bitcast(shl(zext(i32_ty, bitcast(v, i16_ty)), c16), f32_ty);
7880
Value scaleF32 = bitcast(shl(zext(i32_ty, scale), i32_val(23)), f32_ty);
7981
Value mulF32 = fmul(vF32, scaleF32);
8082
Value mulI16 = trunc(i16_ty, lshr(bitcast(mulF32, i32_ty), c16));
83+
Value mulBf16 = bitcast(mulI16, bf16_ty);
84+
if (fastMath)
85+
return mulBf16;
8186
// Account for NaN in the scale as per the mxfp specification.
8287
Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
8388
Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
84-
return select(scaleIsNan, nanBf16, bitcast(mulI16, bf16_ty));
89+
return select(scaleIsNan, nanBf16, mulBf16);
8590
};
8691

8792
class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
@@ -166,9 +171,10 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
166171
for (int j = 0; j < 32; ++j) {
167172
int index = 32 * i + j;
168173
xVals[index] =
169-
useFp16 ? mxfpScaleFp16(rewriter, loc, xVals[index], si[j / 16])
174+
useFp16 ? mxfpScaleFp16(rewriter, loc, xVals[index], si[j / 16],
175+
op.getFastMath())
170176
: mxfpScaleBf16ViaF32(rewriter, loc, xVals[index],
171-
si[j / 16]);
177+
si[j / 16], op.getFastMath());
172178
}
173179
}
174180
} else {
@@ -190,10 +196,11 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
190196

191197
for (int j = 0; j < 32; ++j) {
192198
int index = 32 * i + j;
193-
xVals[index] =
194-
useFp16
195-
? mxfpScaleFp16(rewriter, loc, xVals[index], si[j / 16])
196-
: mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], si[j / 8]);
199+
xVals[index] = useFp16
200+
? mxfpScaleFp16(rewriter, loc, xVals[index],
201+
si[j / 16], op.getFastMath())
202+
: mxfpScaleBf16ViaF32(rewriter, loc, xVals[index],
203+
si[j / 8], op.getFastMath());
197204
}
198205
}
199206
}

0 commit comments

Comments
 (0)