@@ -23,6 +23,70 @@ namespace {
23
23
// ===----------------------------------------------------------------------===//
24
24
// Data type conversion utility functions
25
25
// ===----------------------------------------------------------------------===//
26
+ template <typename FPType> struct FPTypeInfo {
27
+ FPTypeInfo (Location loc, ConversionPatternRewriter &rewriter,
28
+ TritonLLVMOpBuilder &builder)
29
+ : loc(loc), rewriter(rewriter), b(builder) {}
30
+ IntegerType getIntType () {
31
+ if constexpr (std::is_same_v<FPType, Float32Type>) {
32
+ return i32_ty;
33
+ }
34
+ if constexpr (std::is_same_v<FPType, Float16Type> ||
35
+ std::is_same_v<FPType, BFloat16Type>) {
36
+ return i16_ty;
37
+ }
38
+ if constexpr (std::is_same_v<FPType, Float8E4M3FNType> ||
39
+ std::is_same_v<FPType, Float8E5M2Type>) {
40
+ return i8_ty;
41
+ }
42
+ return nullptr ;
43
+ }
44
+
45
+ SmallVector<float > getHalfwayPointsForDstType (TypeID dstTyID) {
46
+ if constexpr (std::is_same_v<FPType, Float32Type>) {
47
+ if (dstTyID == TypeID::get<Float8E4M3FNType>())
48
+ return {0x3a800000 , // halfway between [0/8 * 2^-6, 1/8 * 2^-6]
49
+ 0x3b400000 , // halfway between [1/8 * 2^-6, 2/8 * 2^-6]
50
+ 0x3ba00000 , // halfway between [2/8 * 2^-6, 3/8 * 2^-6]
51
+ 0x3be00000 , // halfway between [3/8 * 2^-6, 4/8 * 2^-6]
52
+ 0x3c100000 , // halfway between [4/8 * 2^-6, 5/8 * 2^-6]
53
+ 0x3c300000 , // halfway between [5/8 * 2^-6, 6/8 * 2^-6]
54
+ 0x3c500000 , // halfway between [6/8 * 2^-6, 7/8 * 2^-6]
55
+ 0x3c700000 }; // halfway between [7/8 * 2^-6, 8/8 * 2^-6]
56
+ if (dstTyID == TypeID::get<Float8E5M2Type>())
57
+ return {0x37000000 , // halfway between [0/4 * 2^(-14), 1/4 * 2^(-14)]
58
+ 0x37c00000 , // halfway between [1/4 * 2^(-14), 2/4 * 2^(-14)]
59
+ 0x38200000 , // halfway between [2/4 * 2^(-14), 3/4 * 2^(-14)]
60
+ 0x38600000 }; // halfway between [3/4 * 2^(-14), 4/4 * 2^(-14)]
61
+ }
62
+ if constexpr (std::is_same_v<FPType, Float16Type>) {
63
+ if (dstTyID == TypeID::get<Float8E4M3FNType>())
64
+ return {0x1400 , 0x1A00 , 0x1D00 , 0x1F00 , 0x2080 , 0x2180 , 0x2280 , 0x2380 };
65
+ if (dstTyID == TypeID::get<Float8E5M2Type>())
66
+ return {0x0080 , 0x0180 , 0x0200 , 0x0380 };
67
+ }
68
+ return {};
69
+ }
70
+
71
+ Value toLLVMIntValue (int32_t val) {
72
+ if constexpr (std::is_same_v<FPType, Float32Type>) {
73
+ return b.i32_val (val);
74
+ }
75
+ if constexpr (std::is_same_v<FPType, Float16Type> ||
76
+ std::is_same_v<FPType, BFloat16Type>) {
77
+ return b.i16_val (val);
78
+ }
79
+ if constexpr (std::is_same_v<FPType, Float8E4M3FNType> ||
80
+ std::is_same_v<FPType, Float8E5M2Type>) {
81
+ return b.i8_val (val);
82
+ }
83
+ return nullptr ;
84
+ }
85
+ Location loc;
86
+ ConversionPatternRewriter &rewriter;
87
+ TritonLLVMOpBuilder &b;
88
+ };
89
+
26
90
// Convert Ocp Fp8/Bf8 to Fp16/Bf16/Fp32 on CDNA4
27
91
template <typename ConvertOp>
28
92
static SmallVector<Value>
@@ -111,6 +175,7 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
111
175
static SmallVector<Value>
112
176
Fp16_to_Fp8E5M2_RTNE_SW (Location loc, ConversionPatternRewriter &rewriter,
113
177
const SmallVector<Value> &v) {
178
+
114
179
assert (v.size () == 4 );
115
180
auto b = TritonLLVMOpBuilder (loc, rewriter);
116
181
@@ -203,88 +268,155 @@ static Value checkIsNan(TritonLLVMOpBuilder &builder, Value v) {
203
268
->getResult (0 );
204
269
}
205
270
206
- // Fp16 -> OCP Fp8 (RTNZ)
207
-
208
- // Cast FP16 to FP8E4M3FN in saturation and round-to-nearest-even mode.
271
+ // Cast Fp32 or FP16 to FP8E4M3FN in saturation and round-to-nearest-even mode.
209
272
// According to
210
273
// https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1,
211
274
// In saturation mode, inf and out-of-range numbers are converted to the largest
212
275
// normal number, i.e. ±448. NaNs are converted to NaNs.
213
- static Value
214
- Fp16_to_Fp8E4M3FN_RTNE_oneValue (Location loc,
215
- ConversionPatternRewriter &rewriter, Value v) {
276
+ template <typename SrcFPType>
277
+ static Value Fp_to_Fp8E4M3FN_RTNE_oneValue (Location loc,
278
+ ConversionPatternRewriter &rewriter,
279
+ Value v) {
280
+ static_assert ((std::is_same_v<SrcFPType, Float32Type>) ||
281
+ (std::is_same_v<SrcFPType, Float16Type>));
216
282
auto b = TritonLLVMOpBuilder (loc, rewriter);
283
+ const llvm::fltSemantics *srcSemantic = nullptr ;
284
+ if constexpr (std::is_same_v<SrcFPType, Float32Type>)
285
+ srcSemantic = &llvm::APFloat::IEEEsingle ();
286
+ else
287
+ srcSemantic = &llvm::APFloat::IEEEhalf ();
288
+ auto srcWidth = llvm::APFloat::getSizeInBits (*srcSemantic);
289
+ auto srcMantissaBits = llvm::APFloat::semanticsPrecision (*srcSemantic) - 1 ;
290
+ auto srcExponentBits = srcWidth - srcMantissaBits - 1 ;
291
+ auto srcBias = (1 << (srcExponentBits - 1 )) - 1 ;
292
+
293
+ const llvm::fltSemantics &dstSemantic = llvm::APFloat::Float8E4M3FN ();
294
+ auto dstWidth = llvm::APFloat::getSizeInBits (dstSemantic);
295
+ auto dstMantissaBits = llvm::APFloat::semanticsPrecision (dstSemantic) - 1 ;
296
+ auto dstExponentBits = dstWidth - dstMantissaBits - 1 ;
297
+ auto dstBias = (1 << (dstExponentBits - 1 )) - 1 ;
298
+
299
+ FPTypeInfo<SrcFPType> srcFpInfo (loc, rewriter, b);
300
+ FPTypeInfo<Float8E4M3FNType> dstFpInfo (loc, rewriter, b);
301
+ auto srcIntType = srcFpInfo.getIntType ();
217
302
Value isNaN = checkIsNan (b, v);
303
+
304
+ uint32_t reducedMantissaBits = srcMantissaBits - dstMantissaBits;
305
+ Value reducedMantissaValue = srcFpInfo.toLLVMIntValue (reducedMantissaBits);
306
+
218
307
// Get sign and absolute value
219
- Value vi16 = b.bitcast (v, i16_ty);
308
+ Value intVal = b.bitcast (v, srcIntType);
309
+ int32_t signMask = 1 << (srcWidth - 1 );
220
310
Value sign =
221
- b.trunc (i8_ty, b.lshr (b.and_ (vi16, b.i16_val (0x8000 )), b.i16_val (8 )));
222
- vi16 = b.and_ (vi16, b.i16_val (0x7FFF ));
311
+ b.trunc (i8_ty, b.lshr (b.and_ (intVal, srcFpInfo.toLLVMIntValue (signMask)),
312
+ srcFpInfo.toLLVMIntValue (srcWidth - 8 )));
313
+
314
+ int32_t absoluteMask = signMask - 1 ;
315
+ intVal = b.and_ (intVal, srcFpInfo.toLLVMIntValue (absoluteMask));
223
316
224
317
// Rounding to nearest even
225
- constexpr uint16_t baseRoundingBias = 0x003F ; // 1 << (10 - 3 - 1) - 1
318
+ uint32_t baseRoundingBias = ( 1 << (reducedMantissaBits - 1 )) - 1 ;
226
319
227
- // S.EEEEE.MMMMMMMMMM => 0.00000.00M0000000 => 0.00000.000000000M
320
+ // For Fp16, S.EEEEE.MMMMMMMMMM => 0.00000.00M0000000 => 0.00000.000000000M
321
+ uint32_t mantissaLSB = 1 << reducedMantissaBits;
322
+ Value mantissaLSBValue = srcFpInfo.toLLVMIntValue (mantissaLSB);
228
323
Value remainingMantissaLSB =
229
- b.lshr (b.and_ (vi16, b.i16_val (0x0080 )), b.i16_val (7 ));
230
- Value roundingBias = b.add (remainingMantissaLSB, b.i16_val (baseRoundingBias));
231
- Value vFp8 = b.add (vi16, roundingBias);
324
+ b.lshr (b.and_ (intVal, mantissaLSBValue), reducedMantissaValue);
325
+ Value roundingBias =
326
+ b.add (remainingMantissaLSB, srcFpInfo.toLLVMIntValue (baseRoundingBias));
327
+ Value vFp8 = b.add (intVal, roundingBias);
232
328
233
329
// Reduce mantissa to 3 bits
234
- vFp8 = b.and_ (vFp8, b.i16_val (0xFF80 )); // 0xFF80 == 1.11111.1110000000
235
-
236
- // 0x2400 is the FP16 representation of 2^{-6}, which is the smallest normal
237
- // number in FP8E4M3FN. We round numbers smaller than that to 0x2400 to make
330
+ // For Fp16, reduceMantissaMask == 1.11111.1110000000
331
+ uint32_t reduceMantissaMask =
332
+ ((1 << (1 + srcExponentBits + dstMantissaBits + 1 )) - 1 )
333
+ << reducedMantissaBits;
334
+ Value reduceMantissa = srcFpInfo.toLLVMIntValue (reduceMantissaMask);
335
+ vFp8 = b.and_ (vFp8, reduceMantissa);
336
+
337
+ // We round numbers smaller than the minimal normal number in Fp8 to make
238
338
// it easier to handle subnormals
239
- vFp8 = b.umax (vFp8, b.i16_val (0x2400 ));
339
+ auto dstSmallest = llvm::APFloat::getSmallestNormalized (dstSemantic);
340
+ // Get the srcFpType representation of the minimal normal number in Fp8
341
+ bool losesInfo;
342
+ dstSmallest.convert (*srcSemantic, APFloat::rmNearestTiesToEven, &losesInfo);
343
+ uint32_t dstMinimal =
344
+ static_cast <uint32_t >(dstSmallest.bitcastToAPInt ().getZExtValue ());
345
+ vFp8 = b.umax (vFp8, srcFpInfo.toLLVMIntValue (dstMinimal));
240
346
241
347
// Adjust exponent bias
242
- vFp8 = b.sub (vFp8, b.i16_val (0x2000 )); // (15 - 7) << 10
348
+ uint32_t expBias = (srcBias - dstBias) << srcMantissaBits;
349
+ vFp8 = b.sub (vFp8, srcFpInfo.toLLVMIntValue (expBias));
243
350
244
351
// Shift right and truncate
245
- vFp8 = b.trunc (i8_ty, b.lshr (vFp8, b.i16_val (7 ))); // 10 - 3
246
-
247
- // 0x5F7F == 0.10111.1101111111 is the largest possible normal
248
- // number(including infinity) after rounding in FP8
249
- //
250
- // In saturation mode, numbers larger than the max normal number(including
251
- // infinity) in FP8 after rounding will be replaced with max_E4M3, i.e. 0x7E
252
- // === 0.1111.110
253
- Value isOverflowOrInf = b.icmp_ugt (vi16, b.i16_val (0x5F7F ));
254
- vFp8 = b.select (isOverflowOrInf, b.i8_val (0x7E ), vFp8);
352
+ vFp8 = b.trunc (i8_ty, b.lshr (vFp8, reducedMantissaValue));
353
+
354
+ // Any numbers larger than the max normal number(including infinity) in FP8
355
+ // after rounding will cause overflow
356
+ auto dstLargest = llvm::APFloat::getLargest (dstSemantic);
357
+ uint32_t dstMaxPositive =
358
+ static_cast <uint32_t >(dstLargest.bitcastToAPInt ().getZExtValue ());
359
+ // Get the srcFpType representation of the maximal normal number in Fp8
360
+ dstLargest.convert (*srcSemantic, APFloat::rmNearestTiesToEven, &losesInfo);
361
+ uint32_t dstMaxOfSrcType =
362
+ static_cast <uint32_t >(dstLargest.bitcastToAPInt ().getZExtValue ());
363
+
364
+ // For Fp16, 0x5F7F == 0.10111.1101111111 is the largest possible normal
365
+ // number(including infinity) after rounding in FP8E4M3
366
+ if constexpr (std::is_same_v<SrcFPType, Float32Type>)
367
+ dstMaxOfSrcType |= 0x7ffff ;
368
+ else
369
+ dstMaxOfSrcType |= 0x7f ;
370
+ Value isOverflowOrInf =
371
+ b.icmp_ugt (intVal, srcFpInfo.toLLVMIntValue (dstMaxOfSrcType));
372
+ vFp8 =
373
+ b.select (isOverflowOrInf, dstFpInfo.toLLVMIntValue (dstMaxPositive), vFp8);
255
374
256
375
// Round subnormals to nearest even. Ref:
257
376
// https://github.com/openxla/xla/blob/f20c6fe2/xla/service/elemental_ir_emitter.cc#L272
258
377
constexpr size_t lutSize = 8 ;
259
- constexpr float halfwayPointsLUT[lutSize] = {0x1400 , 0x1A00 , 0x1D00 , 0x1F00 ,
260
- 0x2080 , 0x2180 , 0x2280 , 0x2380 };
378
+ auto dstTyID = TypeID::get<Float8E4M3FNType>();
379
+ SmallVector<float > halfwayPointsLUT =
380
+ srcFpInfo.getHalfwayPointsForDstType (dstTyID);
261
381
262
382
for (int i = lutSize - 1 ; i >= 0 ; i--) {
263
383
Value cmp;
264
384
if (i % 2 == 0 ) {
265
- cmp = b.icmp_ule (vi16, b. i16_val (halfwayPointsLUT[i]));
385
+ cmp = b.icmp_ule (intVal, srcFpInfo. toLLVMIntValue (halfwayPointsLUT[i]));
266
386
} else {
267
- cmp = b.icmp_ult (vi16, b. i16_val (halfwayPointsLUT[i]));
387
+ cmp = b.icmp_ult (intVal, srcFpInfo. toLLVMIntValue (halfwayPointsLUT[i]));
268
388
}
269
389
270
390
vFp8 = b.select (cmp, b.i8_val (i), vFp8);
271
391
}
272
392
273
393
// NaN remains NaN after conversion
274
- vFp8 = b.select (isNaN, b.i8_val (0x7F ), vFp8);
394
+ int32_t positiveNan = (1 << (dstExponentBits + dstMantissaBits)) - 1 ;
395
+ vFp8 = b.select (isNaN, dstFpInfo.toLLVMIntValue (positiveNan), vFp8);
275
396
276
397
// Set sign bit
277
398
vFp8 = b.or_ (vFp8, sign);
278
399
279
400
return vFp8;
280
401
}
281
402
403
+ // Fp32 -> OCP Fp8 (RTNZ)
404
+ static SmallVector<Value>
405
+ Fp32_to_Fp8E4M3FN_RTNE_SW (Location loc, ConversionPatternRewriter &rewriter,
406
+ const SmallVector<Value> &v) {
407
+ SmallVector<Value> result (2 );
408
+ result[0 ] = Fp_to_Fp8E4M3FN_RTNE_oneValue<Float32Type>(loc, rewriter, v[0 ]);
409
+ result[1 ] = Fp_to_Fp8E4M3FN_RTNE_oneValue<Float32Type>(loc, rewriter, v[1 ]);
410
+ return result;
411
+ }
412
+
413
+ // Fp16 -> OCP Fp8 (RTNZ)
282
414
static SmallVector<Value>
283
415
Fp16_to_Fp8E4M3FN_RTNE_SW (Location loc, ConversionPatternRewriter &rewriter,
284
416
const SmallVector<Value> &v) {
285
417
SmallVector<Value> result (2 );
286
- result[0 ] = Fp16_to_Fp8E4M3FN_RTNE_oneValue (loc, rewriter, v[0 ]);
287
- result[1 ] = Fp16_to_Fp8E4M3FN_RTNE_oneValue (loc, rewriter, v[1 ]);
418
+ result[0 ] = Fp_to_Fp8E4M3FN_RTNE_oneValue<Float16Type> (loc, rewriter, v[0 ]);
419
+ result[1 ] = Fp_to_Fp8E4M3FN_RTNE_oneValue<Float16Type> (loc, rewriter, v[1 ]);
288
420
return result;
289
421
}
290
422
@@ -377,14 +509,21 @@ static SmallVector<Value> Fp8E5M2_to_Fp32(Location loc,
377
509
}
378
510
379
511
// Convert Fp32 to OCP Fp8 on CDNA4
380
- static SmallVector<Value> Fp32_to_Fp8E4M3FN (Location loc,
381
- ConversionPatternRewriter &rewriter,
382
- const SmallVector<Value> &v) {
512
+
513
+ static SmallVector<Value>
514
+ Fp32_to_Fp8E4M3FN_RTNE_HW (Location loc, ConversionPatternRewriter &rewriter,
515
+ const SmallVector<Value> &v) {
383
516
assert (v.size () == 2 );
384
517
return cvtScalePkDowncastToFp8<ROCDL::CvtScaleF32PkFp8F32Op>(loc, rewriter,
385
518
v[0 ], v[1 ]);
386
519
}
387
520
521
+ // Fp32 -> OCP Fp8 (RTNE)
522
+ ConverterT Fp32_to_Fp8E4M3FN_RTNE (AMD::ISAFamily isaFamily) {
523
+ return isaFamily == AMD::ISAFamily::CDNA4 ? Fp32_to_Fp8E4M3FN_RTNE_HW
524
+ : Fp32_to_Fp8E4M3FN_RTNE_SW;
525
+ }
526
+
388
527
// Fp32 -> OCP Bf8 (RTNE)
389
528
390
529
static SmallVector<Value>
@@ -1343,7 +1482,8 @@ struct FpToFpOpConversion
1343
1482
Fp32_to_Fp8E4M3FNUZ},
1344
1483
{{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
1345
1484
Fp32_to_Fp8E5M2FNUZ},
1346
- {{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE}, Fp32_to_Fp8E4M3FN},
1485
+ {{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE},
1486
+ Fp32_to_Fp8E4M3FN_RTNE (isaFamily)},
1347
1487
{{F32TyID, F8E5M2TyID, RoundingMode::RTNE},
1348
1488
Fp32_to_Fp8E5M2_RTNE (isaFamily)},
1349
1489
{{F32TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp32_to_Fp8E5M2_RTZ},
@@ -1406,8 +1546,8 @@ struct FpToFpOpConversion
1406
1546
// - fp16 -> fp8 with rtne
1407
1547
// with the following exceptions:
1408
1548
// 1. fp32 -> ocp fp8/bf8 on CDNA4: has hardware support
1409
- // 2. fp32 -> nanoo fp8/bf8 on non-CDNA4 : has hardware support
1410
- // 3. fp32 -> ocp bf8 on non-CDNA4: has software support
1549
+ // 2. fp32 -> nanoo fp8/bf8 on CDNA3 : has hardware support
1550
+ // 3. fp32 -> ocp fp8/ bf8 on non-CDNA4: has software support
1411
1551
bool useFP16IntermediateSrc =
1412
1552
srcElementType.isF32 () && !dstElementType.isF16 () &&
1413
1553
roundingMode == RoundingMode::RTNE &&
@@ -1417,7 +1557,7 @@ struct FpToFpOpConversion
1417
1557
(llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
1418
1558
dstElementType))) &&
1419
1559
!(isaFamily != AMD::ISAFamily::CDNA4 &&
1420
- (llvm::isa<Float8E5M2Type>(dstElementType)));
1560
+ (llvm::isa<Float8E5M2Type, Float8E4M3FNType >(dstElementType)));
1421
1561
1422
1562
// fp8/bf8->f32, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8 on CDNA4,
1423
1563
// is done in two steps: fp8/bf8->fp16 and fp16->fp32
0 commit comments