@@ -65,6 +65,12 @@ template <typename FPType> struct FPTypeInfo {
65
65
if (dstTyID == TypeID::get<Float8E5M2Type>())
66
66
return {0x0080 , 0x0180 , 0x0200 , 0x0380 };
67
67
}
68
+ if constexpr (std::is_same_v<FPType, BFloat16Type>) {
69
+ if (dstTyID == TypeID::get<Float8E4M3FNType>())
70
+ return {0x3a80 , 0x3b40 , 0x3ba0 , 0x3be0 , 0x3c10 , 0x3c30 , 0x3c50 , 0x3c70 };
71
+ if (dstTyID == TypeID::get<Float8E5M2Type>())
72
+ return {0x3700 , 0x37c0 , 0x3820 , 0x3860 };
73
+ }
68
74
return {};
69
75
}
70
76
@@ -278,13 +284,16 @@ static Value Fp_to_Fp8E4M3FN_RTNE_oneValue(Location loc,
278
284
ConversionPatternRewriter &rewriter,
279
285
Value v) {
280
286
static_assert ((std::is_same_v<SrcFPType, Float32Type>) ||
281
- (std::is_same_v<SrcFPType, Float16Type>));
287
+ (std::is_same_v<SrcFPType, Float16Type>) ||
288
+ (std::is_same_v<SrcFPType, BFloat16Type>));
282
289
auto b = TritonLLVMOpBuilder (loc, rewriter);
283
290
const llvm::fltSemantics *srcSemantic = nullptr ;
284
291
if constexpr (std::is_same_v<SrcFPType, Float32Type>)
285
292
srcSemantic = &llvm::APFloat::IEEEsingle ();
286
- else
293
+ else if constexpr (std::is_same_v<SrcFPType, Float16Type>)
287
294
srcSemantic = &llvm::APFloat::IEEEhalf ();
295
+ else
296
+ srcSemantic = &llvm::APFloat::BFloat ();
288
297
auto srcWidth = llvm::APFloat::getSizeInBits (*srcSemantic);
289
298
auto srcMantissaBits = llvm::APFloat::semanticsPrecision (*srcSemantic) - 1 ;
290
299
auto srcExponentBits = srcWidth - srcMantissaBits - 1 ;
@@ -365,8 +374,10 @@ static Value Fp_to_Fp8E4M3FN_RTNE_oneValue(Location loc,
365
374
// number(including infinity) after rounding in FP8E4M3
366
375
if constexpr (std::is_same_v<SrcFPType, Float32Type>)
367
376
dstMaxOfSrcType |= 0x7ffff ;
368
- else
377
+ else if constexpr (std::is_same_v<SrcFPType, Float16Type>)
369
378
dstMaxOfSrcType |= 0x7f ;
379
+ else
380
+ dstMaxOfSrcType |= 0x7 ;
370
381
Value isOverflowOrInf =
371
382
b.icmp_ugt (intVal, srcFpInfo.toLLVMIntValue (dstMaxOfSrcType));
372
383
vFp8 =
@@ -1168,15 +1179,30 @@ static ConverterT Bf16_to_Fp8E5M2(AMD::ISAFamily isaFamily) {
1168
1179
return isaFamily == AMD::ISAFamily::CDNA4 ? Bf16_to_Fp8E5M2_HW
1169
1180
: Bf16_to_Fp8E5M2_SW;
1170
1181
}
1171
- // Bf16 -> OCP Fp8
1172
- static SmallVector<Value> Bf16_to_Fp8E4M3FN (Location loc,
1173
- ConversionPatternRewriter &rewriter,
1174
- const SmallVector<Value> &v) {
1182
+
1183
+ // Bf16 -> OCP Fp8 using RTNE
1184
+ static SmallVector<Value>
1185
+ Bf16_to_Fp8E4M3FN_RTNE_SW (Location loc, ConversionPatternRewriter &rewriter,
1186
+ const SmallVector<Value> &v) {
1187
+ SmallVector<Value> result (2 );
1188
+ result[0 ] = Fp_to_Fp8E4M3FN_RTNE_oneValue<BFloat16Type>(loc, rewriter, v[0 ]);
1189
+ result[1 ] = Fp_to_Fp8E4M3FN_RTNE_oneValue<BFloat16Type>(loc, rewriter, v[1 ]);
1190
+ return result;
1191
+ }
1192
+
1193
+ static SmallVector<Value>
1194
+ Bf16_to_Fp8E4M3FN_RTNE_HW (Location loc, ConversionPatternRewriter &rewriter,
1195
+ const SmallVector<Value> &v) {
1175
1196
assert (v.size () == 2 );
1176
1197
return cvtScalePkDowncastToFp8<ROCDL::CvtScaleF32PkFp8Bf16Op>(loc, rewriter,
1177
1198
v[0 ], v[1 ]);
1178
1199
}
1179
1200
1201
+ ConverterT Bf16_to_Fp8E4M3FN (AMD::ISAFamily isaFamily) {
1202
+ return isaFamily == AMD::ISAFamily::CDNA4 ? Bf16_to_Fp8E4M3FN_RTNE_HW
1203
+ : Bf16_to_Fp8E4M3FN_RTNE_SW;
1204
+ }
1205
+
1180
1206
// fp8e4m3fn to bf16
1181
1207
static SmallVector<Value>
1182
1208
Fp8E4M3FN_to_Bf16_SW (Location loc, ConversionPatternRewriter &rewriter,
@@ -1472,7 +1498,8 @@ struct FpToFpOpConversion
1472
1498
// BF16 -> F8
1473
1499
{{BF16TyID, F8E5M2TyID, RoundingMode::RTNE},
1474
1500
Bf16_to_Fp8E5M2 (isaFamily)},
1475
- {{BF16TyID, F8E4M3FNTyID, RoundingMode::RTNE}, Bf16_to_Fp8E4M3FN},
1501
+ {{BF16TyID, F8E4M3FNTyID, RoundingMode::RTNE},
1502
+ Bf16_to_Fp8E4M3FN (isaFamily)},
1476
1503
{{BF16TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
1477
1504
Bf16_to_Fp8E5M2FNUZ},
1478
1505
{{BF16TyID, F8E4M3FNUZTyID, RoundingMode::RTNE},
0 commit comments