Skip to content

Commit fcf3e3e

Browse files
authored
[AMD] Fix fp32/fp16 to OCP fp8 conversion on MI300 (#7382)
The current implementation for converting FP32 to OCP FP8 on MI300 involves two steps: FP32 → FP16 and FP16 → OCP FP8. However, this approach produces incorrect results for subnormal numbers. To fix the issue, this patch introduces a direct conversion from FP32 to OCP FP8.
1 parent 4fe73e1 commit fcf3e3e

File tree

4 files changed

+187
-52
lines changed

4 files changed

+187
-52
lines changed

python/test/unit/language/test_conversions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,8 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
344344
if dst_dtype == 'float8e4nv':
345345
if not rounding == 'rtne':
346346
pytest.skip("float8e4nv downcast tests only supported with RTNE rounding on AMDGPU")
347-
if not (is_hip_cdna3() and src_dtype == 'float16' or is_hip_cdna4()):
348-
pytest.skip("float8e4nv downcast tests only supported on AMDGPU CDNA3 or on CDNA4 and from float16 with RTNE rounding")
347+
if not is_hip_cdna4() and src_dtype == 'bfloat16':
348+
pytest.skip("float8e4nv downcast tests from bfloat16 only supported on AMDGPU CDNA4")
349349

350350
if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and not is_hip_cdna3():
351351
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")

python/triton_kernels/tests/test_matmul.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,6 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
257257
if split_k > 1:
258258
pytest.skip("splitK hasn't been fully tested on AMD GPU.")
259259

260-
if is_hip_cdna3() and ("float8_e4m3fn" in (weight_dtype_str, act_dtype_str)):
261-
pytest.skip("float8_e4m3fn hasn't been fully tested on AMD CDNA3 platform.")
262-
263260
if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
264261
pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")
265262

python/triton_kernels/tests/test_mxfp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
upcast_from_mxfp_torch,
2323
)
2424
from triton_kernels.testing import assert_close, assert_equal
25-
from triton_kernels.target_info import is_hip, is_hip_cdna3
25+
from triton_kernels.target_info import is_hip
2626

2727

2828
def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
@@ -146,8 +146,6 @@ def test_mxfp_casting(
146146
if is_hip():
147147
if swizzle_value is not None or swizzle_scale is not None:
148148
pytest.skip("Other swizzling patterns are not supported by AMD GPU")
149-
if quant_dtype == 'float8_e4m3fn' and is_hip_cdna3():
150-
pytest.skip("float8_e4m3fn cast hasn't been fully tested on AMD CDNA3")
151149

152150
swizzle_axis = swizzle_axis if (swizzle_value or swizzle_scale) else None
153151
quant_torch_type = dtype_str_to_torch(quant_dtype)

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 184 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,70 @@ namespace {
2323
//===----------------------------------------------------------------------===//
2424
// Data type conversion utility functions
2525
//===----------------------------------------------------------------------===//
26+
template <typename FPType> struct FPTypeInfo {
27+
FPTypeInfo(Location loc, ConversionPatternRewriter &rewriter,
28+
TritonLLVMOpBuilder &builder)
29+
: loc(loc), rewriter(rewriter), b(builder) {}
30+
IntegerType getIntType() {
31+
if constexpr (std::is_same_v<FPType, Float32Type>) {
32+
return i32_ty;
33+
}
34+
if constexpr (std::is_same_v<FPType, Float16Type> ||
35+
std::is_same_v<FPType, BFloat16Type>) {
36+
return i16_ty;
37+
}
38+
if constexpr (std::is_same_v<FPType, Float8E4M3FNType> ||
39+
std::is_same_v<FPType, Float8E5M2Type>) {
40+
return i8_ty;
41+
}
42+
return nullptr;
43+
}
44+
45+
SmallVector<float> getHalfwayPointsForDstType(TypeID dstTyID) {
46+
if constexpr (std::is_same_v<FPType, Float32Type>) {
47+
if (dstTyID == TypeID::get<Float8E4M3FNType>())
48+
return {0x3a800000, // halfway between [0/8 * 2^-6, 1/8 * 2^-6]
49+
0x3b400000, // halfway between [1/8 * 2^-6, 2/8 * 2^-6]
50+
0x3ba00000, // halfway between [2/8 * 2^-6, 3/8 * 2^-6]
51+
0x3be00000, // halfway between [3/8 * 2^-6, 4/8 * 2^-6]
52+
0x3c100000, // halfway between [4/8 * 2^-6, 5/8 * 2^-6]
53+
0x3c300000, // halfway between [5/8 * 2^-6, 6/8 * 2^-6]
54+
0x3c500000, // halfway between [6/8 * 2^-6, 7/8 * 2^-6]
55+
0x3c700000}; // halfway between [7/8 * 2^-6, 8/8 * 2^-6]
56+
if (dstTyID == TypeID::get<Float8E5M2Type>())
57+
return {0x37000000, // halfway between [0/4 * 2^(-14), 1/4 * 2^(-14)]
58+
0x37c00000, // halfway between [1/4 * 2^(-14), 2/4 * 2^(-14)]
59+
0x38200000, // halfway between [2/4 * 2^(-14), 3/4 * 2^(-14)]
60+
0x38600000}; // halfway between [3/4 * 2^(-14), 4/4 * 2^(-14)]
61+
}
62+
if constexpr (std::is_same_v<FPType, Float16Type>) {
63+
if (dstTyID == TypeID::get<Float8E4M3FNType>())
64+
return {0x1400, 0x1A00, 0x1D00, 0x1F00, 0x2080, 0x2180, 0x2280, 0x2380};
65+
if (dstTyID == TypeID::get<Float8E5M2Type>())
66+
return {0x0080, 0x0180, 0x0200, 0x0380};
67+
}
68+
return {};
69+
}
70+
71+
Value toLLVMIntValue(int32_t val) {
72+
if constexpr (std::is_same_v<FPType, Float32Type>) {
73+
return b.i32_val(val);
74+
}
75+
if constexpr (std::is_same_v<FPType, Float16Type> ||
76+
std::is_same_v<FPType, BFloat16Type>) {
77+
return b.i16_val(val);
78+
}
79+
if constexpr (std::is_same_v<FPType, Float8E4M3FNType> ||
80+
std::is_same_v<FPType, Float8E5M2Type>) {
81+
return b.i8_val(val);
82+
}
83+
return nullptr;
84+
}
85+
Location loc;
86+
ConversionPatternRewriter &rewriter;
87+
TritonLLVMOpBuilder &b;
88+
};
89+
2690
// Convert Ocp Fp8/Bf8 to Fp16/Bf16/Fp32 on CDNA4
2791
template <typename ConvertOp>
2892
static SmallVector<Value>
@@ -111,6 +175,7 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
111175
static SmallVector<Value>
112176
Fp16_to_Fp8E5M2_RTNE_SW(Location loc, ConversionPatternRewriter &rewriter,
113177
const SmallVector<Value> &v) {
178+
114179
assert(v.size() == 4);
115180
auto b = TritonLLVMOpBuilder(loc, rewriter);
116181

@@ -203,88 +268,155 @@ static Value checkIsNan(TritonLLVMOpBuilder &builder, Value v) {
203268
->getResult(0);
204269
}
205270

206-
// Fp16 -> OCP Fp8 (RTNZ)
207-
208-
// Cast FP16 to FP8E4M3FN in saturation and round-to-nearest-even mode.
271+
// Cast Fp32 or FP16 to FP8E4M3FN in saturation and round-to-nearest-even mode.
209272
// According to
210273
// https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1,
211274
// In saturation mode, inf and out-of-range numbers are converted to the largest
212275
// normal number, i.e. ±448. NaNs are converted to NaNs.
213-
static Value
214-
Fp16_to_Fp8E4M3FN_RTNE_oneValue(Location loc,
215-
ConversionPatternRewriter &rewriter, Value v) {
276+
template <typename SrcFPType>
277+
static Value Fp_to_Fp8E4M3FN_RTNE_oneValue(Location loc,
278+
ConversionPatternRewriter &rewriter,
279+
Value v) {
280+
static_assert((std::is_same_v<SrcFPType, Float32Type>) ||
281+
(std::is_same_v<SrcFPType, Float16Type>));
216282
auto b = TritonLLVMOpBuilder(loc, rewriter);
283+
const llvm::fltSemantics *srcSemantic = nullptr;
284+
if constexpr (std::is_same_v<SrcFPType, Float32Type>)
285+
srcSemantic = &llvm::APFloat::IEEEsingle();
286+
else
287+
srcSemantic = &llvm::APFloat::IEEEhalf();
288+
auto srcWidth = llvm::APFloat::getSizeInBits(*srcSemantic);
289+
auto srcMantissaBits = llvm::APFloat::semanticsPrecision(*srcSemantic) - 1;
290+
auto srcExponentBits = srcWidth - srcMantissaBits - 1;
291+
auto srcBias = (1 << (srcExponentBits - 1)) - 1;
292+
293+
const llvm::fltSemantics &dstSemantic = llvm::APFloat::Float8E4M3FN();
294+
auto dstWidth = llvm::APFloat::getSizeInBits(dstSemantic);
295+
auto dstMantissaBits = llvm::APFloat::semanticsPrecision(dstSemantic) - 1;
296+
auto dstExponentBits = dstWidth - dstMantissaBits - 1;
297+
auto dstBias = (1 << (dstExponentBits - 1)) - 1;
298+
299+
FPTypeInfo<SrcFPType> srcFpInfo(loc, rewriter, b);
300+
FPTypeInfo<Float8E4M3FNType> dstFpInfo(loc, rewriter, b);
301+
auto srcIntType = srcFpInfo.getIntType();
217302
Value isNaN = checkIsNan(b, v);
303+
304+
uint32_t reducedMantissaBits = srcMantissaBits - dstMantissaBits;
305+
Value reducedMantissaValue = srcFpInfo.toLLVMIntValue(reducedMantissaBits);
306+
218307
// Get sign and absolute value
219-
Value vi16 = b.bitcast(v, i16_ty);
308+
Value intVal = b.bitcast(v, srcIntType);
309+
int32_t signMask = 1 << (srcWidth - 1);
220310
Value sign =
221-
b.trunc(i8_ty, b.lshr(b.and_(vi16, b.i16_val(0x8000)), b.i16_val(8)));
222-
vi16 = b.and_(vi16, b.i16_val(0x7FFF));
311+
b.trunc(i8_ty, b.lshr(b.and_(intVal, srcFpInfo.toLLVMIntValue(signMask)),
312+
srcFpInfo.toLLVMIntValue(srcWidth - 8)));
313+
314+
int32_t absoluteMask = signMask - 1;
315+
intVal = b.and_(intVal, srcFpInfo.toLLVMIntValue(absoluteMask));
223316

224317
// Rounding to nearest even
225-
constexpr uint16_t baseRoundingBias = 0x003F; // 1 << (10 - 3 - 1) - 1
318+
uint32_t baseRoundingBias = (1 << (reducedMantissaBits - 1)) - 1;
226319

227-
// S.EEEEE.MMMMMMMMMM => 0.00000.00M0000000 => 0.00000.000000000M
320+
// For Fp16, S.EEEEE.MMMMMMMMMM => 0.00000.00M0000000 => 0.00000.000000000M
321+
uint32_t mantissaLSB = 1 << reducedMantissaBits;
322+
Value mantissaLSBValue = srcFpInfo.toLLVMIntValue(mantissaLSB);
228323
Value remainingMantissaLSB =
229-
b.lshr(b.and_(vi16, b.i16_val(0x0080)), b.i16_val(7));
230-
Value roundingBias = b.add(remainingMantissaLSB, b.i16_val(baseRoundingBias));
231-
Value vFp8 = b.add(vi16, roundingBias);
324+
b.lshr(b.and_(intVal, mantissaLSBValue), reducedMantissaValue);
325+
Value roundingBias =
326+
b.add(remainingMantissaLSB, srcFpInfo.toLLVMIntValue(baseRoundingBias));
327+
Value vFp8 = b.add(intVal, roundingBias);
232328

233329
// Reduce mantissa to 3 bits
234-
vFp8 = b.and_(vFp8, b.i16_val(0xFF80)); // 0xFF80 == 1.11111.1110000000
235-
236-
// 0x2400 is the FP16 representation of 2^{-6}, which is the smallest normal
237-
// number in FP8E4M3FN. We round numbers smaller than that to 0x2400 to make
330+
// For Fp16, reduceMantissaMask == 1.11111.1110000000
331+
uint32_t reduceMantissaMask =
332+
((1 << (1 + srcExponentBits + dstMantissaBits + 1)) - 1)
333+
<< reducedMantissaBits;
334+
Value reduceMantissa = srcFpInfo.toLLVMIntValue(reduceMantissaMask);
335+
vFp8 = b.and_(vFp8, reduceMantissa);
336+
337+
// We round numbers smaller than the minimal normal number in Fp8 to make
238338
// it easier to handle subnormals
239-
vFp8 = b.umax(vFp8, b.i16_val(0x2400));
339+
auto dstSmallest = llvm::APFloat::getSmallestNormalized(dstSemantic);
340+
// Get the srcFpType representation of the minimal normal number in Fp8
341+
bool losesInfo;
342+
dstSmallest.convert(*srcSemantic, APFloat::rmNearestTiesToEven, &losesInfo);
343+
uint32_t dstMinimal =
344+
static_cast<uint32_t>(dstSmallest.bitcastToAPInt().getZExtValue());
345+
vFp8 = b.umax(vFp8, srcFpInfo.toLLVMIntValue(dstMinimal));
240346

241347
// Adjust exponent bias
242-
vFp8 = b.sub(vFp8, b.i16_val(0x2000)); // (15 - 7) << 10
348+
uint32_t expBias = (srcBias - dstBias) << srcMantissaBits;
349+
vFp8 = b.sub(vFp8, srcFpInfo.toLLVMIntValue(expBias));
243350

244351
// Shift right and truncate
245-
vFp8 = b.trunc(i8_ty, b.lshr(vFp8, b.i16_val(7))); // 10 - 3
246-
247-
// 0x5F7F == 0.10111.1101111111 is the largest possible normal
248-
// number(including infinity) after rounding in FP8
249-
//
250-
// In saturation mode, numbers larger than the max normal number(including
251-
// infinity) in FP8 after rounding will be replaced with max_E4M3, i.e. 0x7E
252-
// === 0.1111.110
253-
Value isOverflowOrInf = b.icmp_ugt(vi16, b.i16_val(0x5F7F));
254-
vFp8 = b.select(isOverflowOrInf, b.i8_val(0x7E), vFp8);
352+
vFp8 = b.trunc(i8_ty, b.lshr(vFp8, reducedMantissaValue));
353+
354+
// Any numbers larger than the max normal number(including infinity) in FP8
355+
// after rounding will cause overflow
356+
auto dstLargest = llvm::APFloat::getLargest(dstSemantic);
357+
uint32_t dstMaxPositive =
358+
static_cast<uint32_t>(dstLargest.bitcastToAPInt().getZExtValue());
359+
// Get the srcFpType representation of the maximal normal number in Fp8
360+
dstLargest.convert(*srcSemantic, APFloat::rmNearestTiesToEven, &losesInfo);
361+
uint32_t dstMaxOfSrcType =
362+
static_cast<uint32_t>(dstLargest.bitcastToAPInt().getZExtValue());
363+
364+
// For Fp16, 0x5F7F == 0.10111.1101111111 is the largest possible normal
365+
// number(including infinity) after rounding in FP8E4M3
366+
if constexpr (std::is_same_v<SrcFPType, Float32Type>)
367+
dstMaxOfSrcType |= 0x7ffff;
368+
else
369+
dstMaxOfSrcType |= 0x7f;
370+
Value isOverflowOrInf =
371+
b.icmp_ugt(intVal, srcFpInfo.toLLVMIntValue(dstMaxOfSrcType));
372+
vFp8 =
373+
b.select(isOverflowOrInf, dstFpInfo.toLLVMIntValue(dstMaxPositive), vFp8);
255374

256375
// Round subnormals to nearest even. Ref:
257376
// https://github.com/openxla/xla/blob/f20c6fe2/xla/service/elemental_ir_emitter.cc#L272
258377
constexpr size_t lutSize = 8;
259-
constexpr float halfwayPointsLUT[lutSize] = {0x1400, 0x1A00, 0x1D00, 0x1F00,
260-
0x2080, 0x2180, 0x2280, 0x2380};
378+
auto dstTyID = TypeID::get<Float8E4M3FNType>();
379+
SmallVector<float> halfwayPointsLUT =
380+
srcFpInfo.getHalfwayPointsForDstType(dstTyID);
261381

262382
for (int i = lutSize - 1; i >= 0; i--) {
263383
Value cmp;
264384
if (i % 2 == 0) {
265-
cmp = b.icmp_ule(vi16, b.i16_val(halfwayPointsLUT[i]));
385+
cmp = b.icmp_ule(intVal, srcFpInfo.toLLVMIntValue(halfwayPointsLUT[i]));
266386
} else {
267-
cmp = b.icmp_ult(vi16, b.i16_val(halfwayPointsLUT[i]));
387+
cmp = b.icmp_ult(intVal, srcFpInfo.toLLVMIntValue(halfwayPointsLUT[i]));
268388
}
269389

270390
vFp8 = b.select(cmp, b.i8_val(i), vFp8);
271391
}
272392

273393
// NaN remains NaN after conversion
274-
vFp8 = b.select(isNaN, b.i8_val(0x7F), vFp8);
394+
int32_t positiveNan = (1 << (dstExponentBits + dstMantissaBits)) - 1;
395+
vFp8 = b.select(isNaN, dstFpInfo.toLLVMIntValue(positiveNan), vFp8);
275396

276397
// Set sign bit
277398
vFp8 = b.or_(vFp8, sign);
278399

279400
return vFp8;
280401
}
281402

403+
// Fp32 -> OCP Fp8 (RTNE)
404+
static SmallVector<Value>
405+
Fp32_to_Fp8E4M3FN_RTNE_SW(Location loc, ConversionPatternRewriter &rewriter,
406+
const SmallVector<Value> &v) {
407+
SmallVector<Value> result(2);
408+
result[0] = Fp_to_Fp8E4M3FN_RTNE_oneValue<Float32Type>(loc, rewriter, v[0]);
409+
result[1] = Fp_to_Fp8E4M3FN_RTNE_oneValue<Float32Type>(loc, rewriter, v[1]);
410+
return result;
411+
}
412+
413+
// Fp16 -> OCP Fp8 (RTNE)
282414
static SmallVector<Value>
283415
Fp16_to_Fp8E4M3FN_RTNE_SW(Location loc, ConversionPatternRewriter &rewriter,
284416
const SmallVector<Value> &v) {
285417
SmallVector<Value> result(2);
286-
result[0] = Fp16_to_Fp8E4M3FN_RTNE_oneValue(loc, rewriter, v[0]);
287-
result[1] = Fp16_to_Fp8E4M3FN_RTNE_oneValue(loc, rewriter, v[1]);
418+
result[0] = Fp_to_Fp8E4M3FN_RTNE_oneValue<Float16Type>(loc, rewriter, v[0]);
419+
result[1] = Fp_to_Fp8E4M3FN_RTNE_oneValue<Float16Type>(loc, rewriter, v[1]);
288420
return result;
289421
}
290422

@@ -377,14 +509,21 @@ static SmallVector<Value> Fp8E5M2_to_Fp32(Location loc,
377509
}
378510

379511
// Convert Fp32 to OCP Fp8 on CDNA4
380-
static SmallVector<Value> Fp32_to_Fp8E4M3FN(Location loc,
381-
ConversionPatternRewriter &rewriter,
382-
const SmallVector<Value> &v) {
512+
513+
static SmallVector<Value>
514+
Fp32_to_Fp8E4M3FN_RTNE_HW(Location loc, ConversionPatternRewriter &rewriter,
515+
const SmallVector<Value> &v) {
383516
assert(v.size() == 2);
384517
return cvtScalePkDowncastToFp8<ROCDL::CvtScaleF32PkFp8F32Op>(loc, rewriter,
385518
v[0], v[1]);
386519
}
387520

521+
// Fp32 -> OCP Fp8 (RTNE)
522+
ConverterT Fp32_to_Fp8E4M3FN_RTNE(AMD::ISAFamily isaFamily) {
523+
return isaFamily == AMD::ISAFamily::CDNA4 ? Fp32_to_Fp8E4M3FN_RTNE_HW
524+
: Fp32_to_Fp8E4M3FN_RTNE_SW;
525+
}
526+
388527
// Fp32 -> OCP Bf8 (RTNE)
389528

390529
static SmallVector<Value>
@@ -1343,7 +1482,8 @@ struct FpToFpOpConversion
13431482
Fp32_to_Fp8E4M3FNUZ},
13441483
{{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
13451484
Fp32_to_Fp8E5M2FNUZ},
1346-
{{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE}, Fp32_to_Fp8E4M3FN},
1485+
{{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE},
1486+
Fp32_to_Fp8E4M3FN_RTNE(isaFamily)},
13471487
{{F32TyID, F8E5M2TyID, RoundingMode::RTNE},
13481488
Fp32_to_Fp8E5M2_RTNE(isaFamily)},
13491489
{{F32TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp32_to_Fp8E5M2_RTZ},
@@ -1406,8 +1546,8 @@ struct FpToFpOpConversion
14061546
// - fp16 -> fp8 with rtne
14071547
// with the following exceptions:
14081548
// 1. fp32 -> ocp fp8/bf8 on CDNA4: has hardware support
1409-
// 2. fp32 -> nanoo fp8/bf8 on non-CDNA4: has hardware support
1410-
// 3. fp32 -> ocp bf8 on non-CDNA4: has software support
1549+
// 2. fp32 -> nanoo fp8/bf8 on CDNA3: has hardware support
1550+
// 3. fp32 -> ocp fp8/bf8 on non-CDNA4: has software support
14111551
bool useFP16IntermediateSrc =
14121552
srcElementType.isF32() && !dstElementType.isF16() &&
14131553
roundingMode == RoundingMode::RTNE &&
@@ -1417,7 +1557,7 @@ struct FpToFpOpConversion
14171557
(llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
14181558
dstElementType))) &&
14191559
!(isaFamily != AMD::ISAFamily::CDNA4 &&
1420-
(llvm::isa<Float8E5M2Type>(dstElementType)));
1560+
(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(dstElementType)));
14211561

14221562
// fp8/bf8->f32, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8 on CDNA4,
14231563
// is done in two steps: fp8/bf8->fp16 and fp16->fp32

0 commit comments

Comments
 (0)