@@ -108,40 +108,48 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
108
108
109
109
// Fp16 -> OCP Bf8 (RTNE)
110
110
111
- // FP8E5M2 is the open-compute standard FP8E5M2 format. NVIDIA GPU supports it
112
- // natively but we don't have hardware native support on CDNA3.
113
- //
114
- // The SW based downcast with RTNE is not fully functional for the denorm
115
- // values. We need rewrite it if we need to emulate this data type on AMDGPU.
116
111
static SmallVector<Value>
117
112
Fp16_to_Fp8E5M2_RTNE_SW (Location loc, ConversionPatternRewriter &rewriter,
118
113
const SmallVector<Value> &v) {
114
+ assert (v.size () == 4 );
119
115
auto b = TritonLLVMOpBuilder (loc, rewriter);
120
- auto fp16x2VecTy = vec_ty (f16_ty, 2 );
121
- Value fp16x2Vec0 = b.undef (fp16x2VecTy);
122
- Value fp16x2Vec1 = b.undef (fp16x2VecTy);
123
- fp16x2Vec0 = b.insert_element (fp16x2VecTy, fp16x2Vec0, v[0 ], b.i32_val (0 ));
124
- fp16x2Vec0 = b.insert_element (fp16x2VecTy, fp16x2Vec0, v[1 ], b.i32_val (1 ));
125
- fp16x2Vec1 = b.insert_element (fp16x2VecTy, fp16x2Vec1, v[2 ], b.i32_val (0 ));
126
- fp16x2Vec1 = b.insert_element (fp16x2VecTy, fp16x2Vec1, v[3 ], b.i32_val (1 ));
127
-
128
- Value a0 = b.bitcast (fp16x2Vec0, i32_ty);
129
- Value a1 = b.bitcast (fp16x2Vec1, i32_ty);
130
116
131
- a0 = b.and_ (i32_ty, a0, b.i32_val (0xfffefffe ));
132
- a1 = b.and_ (i32_ty, a1, b.i32_val (0xfffefffe ));
133
-
134
- a0 = b.add (i32_ty, a0, b.i32_val (0x00800080 ));
135
- a1 = b.add (i32_ty, a1, b.i32_val (0x00800080 ));
136
-
137
- auto fp8x4VecTy = vec_ty (i8_ty, 4 );
138
- a0 = b.bitcast (a0, fp8x4VecTy);
139
- a1 = b.bitcast (a1, fp8x4VecTy);
117
+ SmallVector<Value> result (4 );
118
+ for (size_t i = 0 ; i < 4 ; ++i) {
119
+ Value fp16 = v[i];
120
+ Value i16 = b.bitcast (fp16, i16_ty);
121
+
122
+ Value s = b.and_ (i16_ty, i16 , b.i16_val (0x8000 ));
123
+ Value exp =
124
+ b.and_ (i16_ty, b.lshr (i16_ty, i16 , b.i16_val (10 )), b.i16_val (0x1F ));
125
+ Value man = b.and_ (i16_ty, i16 , b.i16_val (0x03FF ));
126
+ Value sig = b.and_ (i16_ty, i16 , b.i16_val (0x7FFF ));
127
+
128
+ // Round 10-bit mantissa to 2-bit nearest, ties to even
129
+ Value bias = b.add (
130
+ i16_ty,
131
+ b.lshr (i16_ty, b.and_ (i16_ty, sig, b.i16_val (0x0100 )), b.i16_val (8 )),
132
+ b.i16_val (0x007F ));
133
+ i16 = b.add (i16_ty, sig, bias);
134
+
135
+ // Handle overflow using saturation mode, by setting sig to be the max.
136
+ // Any number equal or larger than 0x7B80 after rounding (including
137
+ // infinite 0x7C00) will cause overflow
138
+ i16 = b.select (b.icmp_uge (sig, b.i16_val (0x7B80 )), b.i16_val (0x7B00 ), i16 );
139
+
140
+ // Handle NaN value by keeping it NaN
141
+ i16 = b.select (
142
+ b.and_ (b.icmp_eq (exp, b.i16_val (0x1F )), b.icmp_ne (man, b.i16_val (0x0 ))),
143
+ b.i16_val (0x7E00 ), i16 );
144
+
145
+ // Add sign bit
146
+ i16 = b.or_ (i16_ty, s, i16 );
147
+
148
+ // Truncate to 8-bit
149
+ result[i] = b.trunc (i8_ty, b.lshr (i16_ty, i16 , b.i16_val (8 )));
150
+ }
140
151
141
- return {b.extract_element (i8_ty, a0, b.i32_val (1 )),
142
- b.extract_element (i8_ty, a0, b.i32_val (3 )),
143
- b.extract_element (i8_ty, a1, b.i32_val (1 )),
144
- b.extract_element (i8_ty, a1, b.i32_val (3 ))};
152
+ return result;
145
153
}
146
154
147
155
static SmallVector<Value>
@@ -377,15 +385,98 @@ static SmallVector<Value> Fp32_to_Fp8E4M3FN(Location loc,
377
385
v[0 ], v[1 ]);
378
386
}
379
387
380
- // Convert Fp32 to OCP Bf8 on CDNA4
388
+ // Fp32 -> OCP Bf8 (RTNE)
389
+
381
390
static SmallVector<Value>
382
- Fp32_to_Fp8E5M2_RTNE (Location loc, ConversionPatternRewriter &rewriter,
383
- const SmallVector<Value> &v) {
391
+ Fp32_to_Fp8E5M2_RTNE_SW (Location loc, ConversionPatternRewriter &rewriter,
392
+ const SmallVector<Value> &v) {
393
+ assert (v.size () == 4 );
394
+ auto b = TritonLLVMOpBuilder (loc, rewriter);
395
+
396
+ SmallVector<Value> result (4 );
397
+ for (size_t i = 0 ; i < 4 ; ++i) {
398
+ Value fp32 = v[i];
399
+ Value i32 = b.bitcast (fp32, i32_ty);
400
+
401
+ Value s = b.and_ (i32_ty, i32 , b.i32_val (0x80000000 ));
402
+ Value exp =
403
+ b.and_ (i32_ty, b.lshr (i32_ty, i32 , b.i32_val (23 )), b.i32_val (0xFF ));
404
+ Value man = b.and_ (i32_ty, i32 , b.i32_val (0x007FFFFF ));
405
+
406
+ // Convert 8-bit exponent to 5-bit
407
+ Value exp5 = b.select (b.icmp_ult (exp, b.i32_val (0x71 )), b.i32_val (0 ),
408
+ b.sub (i32_ty, exp, b.i32_val (0x70 )));
409
+
410
+ // Handle subnormal values (exp5 = 0)
411
+ // - exp < 0x6e: mantissa = 0x00000000 (0)
412
+ // - exp == 0x6e: mantissa = 0x00000000 (0),
413
+ // 0x00200000 (1/4)
414
+ // - exp == 0x6f: mantissa = 0x00200000 (1/4),
415
+ // 0x00400000 (1/2)
416
+ // - exp == 0x70: mantissa = 0x00400000 (1/2),
417
+ // 0x00600000 (3/4),
418
+ // 0x00800000 (1)
419
+ man = b.select (b.icmp_ult (exp, b.i32_val (0x6e )), b.i32_val (0 ), man);
420
+ man = b.select (b.icmp_eq (exp, b.i32_val (0x6e )),
421
+ b.select (b.icmp_ne (man, b.i32_val (0 )), b.i32_val (0x00200000 ),
422
+ b.i32_val (0 )),
423
+ man);
424
+ man = b.select (b.icmp_eq (exp, b.i32_val (0x6f )),
425
+ b.select (b.icmp_uge (man, b.i32_val (0x00400000 )),
426
+ b.i32_val (0x00400000 ), b.i32_val (0x00200000 )),
427
+ man);
428
+ man = b.select (
429
+ b.icmp_eq (exp, b.i32_val (0x70 )),
430
+ b.select (b.icmp_ugt (man, b.i32_val (0x00200000 )),
431
+ b.select (b.icmp_uge (man, b.i32_val (0x00600000 )),
432
+ b.i32_val (0x00800000 ), b.i32_val (0x00600000 )),
433
+ b.i32_val (0x00400000 )),
434
+ man);
435
+
436
+ // Round 23-bit mantissa to 2-bit nearest, ties to even
437
+ Value sig = b.or_ (i32_ty, b.shl (i32_ty, exp5, b.i32_val (23 )), man);
438
+ Value bias =
439
+ b.add (i32_ty,
440
+ b.lshr (i32_ty, b.and_ (i32_ty, sig, b.i32_val (0x00200000 )),
441
+ b.i32_val (21 )),
442
+ b.i32_val (0x000FFFFF ));
443
+ i32 = b.add (i32_ty, sig, bias);
444
+
445
+ // Handle overflow using saturation mode, by setting sig to be the max.
446
+ // Overflow will happen for the following cases:
447
+ // - Any number equal or larger than 0x0F700000 after rounding
448
+ // - Exponent larger than 0x8E (including infinity 0xFF)
449
+ i32 = b.select (b.or_ (b.icmp_ugt (exp, b.i32_val (0x8E )),
450
+ b.icmp_uge (sig, b.i32_val (0x0F700000 ))),
451
+ b.i32_val (0x0F7FFFFF ), i32 );
452
+
453
+ // Handle NaN value by keeping it NaN
454
+ i32 = b.select (
455
+ b.and_ (b.icmp_eq (exp, b.i32_val (0xFF )), b.icmp_ne (man, b.i32_val (0x0 ))),
456
+ b.i32_val (0x0FC00000 ), i32 );
457
+
458
+ // Add sign bit
459
+ i32 = b.or_ (i32_ty, b.lshr (i32_ty, s, b.i32_val (3 )), i32 );
460
+
461
+ // Truncate to 8-bit
462
+ result[i] = b.trunc (i8_ty, b.lshr (i32_ty, i32 , b.i32_val (21 )));
463
+ }
464
+ return result;
465
+ }
466
+
467
+ static SmallVector<Value>
468
+ Fp32_to_Fp8E5M2_RTNE_HW (Location loc, ConversionPatternRewriter &rewriter,
469
+ const SmallVector<Value> &v) {
384
470
assert (v.size () == 2 );
385
471
return cvtScalePkDowncastToFp8<ROCDL::CvtScaleF32PkBf8F32Op>(loc, rewriter,
386
472
v[0 ], v[1 ]);
387
473
}
388
474
475
+ ConverterT Fp32_to_Fp8E5M2_RTNE (AMD::ISAFamily isaFamily) {
476
+ return isaFamily == AMD::ISAFamily::CDNA4 ? Fp32_to_Fp8E5M2_RTNE_HW
477
+ : Fp32_to_Fp8E5M2_RTNE_SW;
478
+ }
479
+
389
480
// Fp32 -> Nanoo Bf8 on CDNA3
390
481
static SmallVector<Value>
391
482
Fp32_to_Fp8E5M2FNUZ (Location loc, ConversionPatternRewriter &rewriter,
@@ -853,86 +944,77 @@ ConverterT Fp8E5M2_to_Bf16(AMD::ISAFamily isaFamily) {
853
944
static SmallVector<Value>
854
945
Bf16_to_Fp8E5M2_SW (Location loc, ConversionPatternRewriter &rewriter,
855
946
const SmallVector<Value> &v) {
947
+ assert (v.size () == 4 );
856
948
auto b = TritonLLVMOpBuilder (loc, rewriter);
857
- auto bf16x2VecTy = vec_ty (bf16_ty, 2 );
858
- Value bf16x2Vec0 = b.undef (bf16x2VecTy);
859
- Value bf16x2Vec1 = b.undef (bf16x2VecTy);
860
- bf16x2Vec0 = b.insert_element (bf16x2VecTy, bf16x2Vec0, v[0 ], b.i32_val (0 ));
861
- bf16x2Vec0 = b.insert_element (bf16x2VecTy, bf16x2Vec0, v[1 ], b.i32_val (1 ));
862
- bf16x2Vec1 = b.insert_element (bf16x2VecTy, bf16x2Vec1, v[2 ], b.i32_val (0 ));
863
- bf16x2Vec1 = b.insert_element (bf16x2VecTy, bf16x2Vec1, v[3 ], b.i32_val (1 ));
864
- bf16x2Vec0 = b.bitcast (bf16x2Vec0, i32_ty);
865
- bf16x2Vec1 = b.bitcast (bf16x2Vec1, i32_ty);
866
-
867
- Value sign0 = b.and_ (i32_ty, bf16x2Vec0, b.i32_val (0x80008000 ));
868
- Value sign1 = b.and_ (i32_ty, bf16x2Vec1, b.i32_val (0x80008000 ));
869
- auto fp8x4VecTy = vec_ty (i8_ty, 4 );
870
- Value sign = b.undef (fp8x4VecTy);
871
- sign0 = b.bitcast (sign0, fp8x4VecTy);
872
- sign1 = b.bitcast (sign1, fp8x4VecTy);
873
- sign = b.insert_element (fp8x4VecTy, sign,
874
- b.extract_element (i8_ty, sign0, b.i32_val (1 )),
875
- b.i32_val (0 ));
876
- sign = b.insert_element (fp8x4VecTy, sign,
877
- b.extract_element (i8_ty, sign0, b.i32_val (3 )),
878
- b.i32_val (1 ));
879
- sign = b.insert_element (fp8x4VecTy, sign,
880
- b.extract_element (i8_ty, sign1, b.i32_val (1 )),
881
- b.i32_val (2 ));
882
- sign = b.insert_element (fp8x4VecTy, sign,
883
- b.extract_element (i8_ty, sign1, b.i32_val (3 )),
884
- b.i32_val (3 ));
885
- sign = b.bitcast (sign, i32_ty);
886
-
887
- Value nosign0 = b.and_ (i32_ty, bf16x2Vec0, b.i32_val (0x7fff7fff ));
888
- Value nosign1 = b.and_ (i32_ty, bf16x2Vec1, b.i32_val (0x7fff7fff ));
889
-
890
- Value nosign_0_0 = b.and_ (i32_ty, nosign0, b.i32_val (0xffff0000 ));
891
- nosign_0_0 = b.umax (i32_ty, nosign_0_0, b.i32_val (0x38000000 ));
892
- nosign_0_0 = b.umin (i32_ty, nosign_0_0, b.i32_val (0x57e00000 ));
893
- Value nosign_0_1 = b.and_ (i32_ty, nosign0, b.i32_val (0x0000ffff ));
894
- nosign_0_1 = b.umax (i32_ty, nosign_0_1, b.i32_val (0x3800 ));
895
- nosign_0_1 = b.umin (i32_ty, nosign_0_1, b.i32_val (0x57e0 ));
896
- nosign0 = b.or_ (i32_ty, nosign_0_0, nosign_0_1);
897
-
898
- Value nosign_1_0 = b.and_ (i32_ty, nosign1, b.i32_val (0xffff0000 ));
899
- nosign_1_0 = b.umax (i32_ty, nosign_1_0, b.i32_val (0x38000000 ));
900
- nosign_1_0 = b.umin (i32_ty, nosign_1_0, b.i32_val (0x57e00000 ));
901
- Value nosign_1_1 = b.and_ (i32_ty, nosign1, b.i32_val (0x0000ffff ));
902
- nosign_1_1 = b.umax (i32_ty, nosign_1_1, b.i32_val (0x3800 ));
903
- nosign_1_1 = b.umin (i32_ty, nosign_1_1, b.i32_val (0x57e0 ));
904
- nosign1 = b.or_ (i32_ty, nosign_1_0, nosign_1_1);
905
-
906
- nosign0 = b.add (i32_ty, nosign0, b.i32_val (0x00100010 ));
907
- nosign1 = b.add (i32_ty, nosign1, b.i32_val (0x00100010 ));
908
- nosign0 = b.sub (i32_ty, nosign0, b.i32_val (0x38003800 ));
909
- nosign1 = b.sub (i32_ty, nosign1, b.i32_val (0x38003800 ));
910
- nosign0 = b.shl (i32_ty, nosign0, b.i32_val (3 ));
911
- nosign1 = b.shl (i32_ty, nosign1, b.i32_val (3 ));
912
-
913
- nosign0 = b.bitcast (nosign0, fp8x4VecTy);
914
- nosign1 = b.bitcast (nosign1, fp8x4VecTy);
915
- Value nosign = b.undef (fp8x4VecTy);
916
- nosign = b.insert_element (fp8x4VecTy, nosign,
917
- b.extract_element (i8_ty, nosign0, b.i32_val (1 )),
918
- b.i32_val (0 ));
919
- nosign = b.insert_element (fp8x4VecTy, nosign,
920
- b.extract_element (i8_ty, nosign0, b.i32_val (3 )),
921
- b.i32_val (1 ));
922
- nosign = b.insert_element (fp8x4VecTy, nosign,
923
- b.extract_element (i8_ty, nosign1, b.i32_val (1 )),
924
- b.i32_val (2 ));
925
- nosign = b.insert_element (fp8x4VecTy, nosign,
926
- b.extract_element (i8_ty, nosign1, b.i32_val (3 )),
927
- b.i32_val (3 ));
928
- nosign = b.bitcast (nosign, i32_ty);
929
-
930
- Value fp8x4Vec = b.or_ (i32_ty, nosign, sign);
931
- fp8x4Vec = b.bitcast (fp8x4Vec, fp8x4VecTy);
932
- return {b.extract_element (i8_ty, fp8x4Vec, b.i32_val (0 )),
933
- b.extract_element (i8_ty, fp8x4Vec, b.i32_val (1 )),
934
- b.extract_element (i8_ty, fp8x4Vec, b.i32_val (2 )),
935
- b.extract_element (i8_ty, fp8x4Vec, b.i32_val (3 ))};
949
+
950
+ SmallVector<Value> result (4 );
951
+ for (size_t i = 0 ; i < 4 ; ++i) {
952
+ Value fp16 = v[i];
953
+ Value i16 = b.bitcast (fp16, i16_ty);
954
+
955
+ Value s = b.and_ (i16_ty, i16 , b.i16_val (0x8000 ));
956
+ Value exp =
957
+ b.and_ (i16_ty, b.lshr (i16_ty, i16 , b.i16_val (7 )), b.i16_val (0xFF ));
958
+ Value man = b.and_ (i16_ty, i16 , b.i16_val (0x7F ));
959
+
960
+ // Convert 8-bit exponent to 5-bit exponent
961
+ Value exp5 = b.select (b.icmp_ult (exp, b.i16_val (0x71 )), b.i16_val (0 ),
962
+ b.sub (i16_ty, exp, b.i16_val (0x70 )));
963
+
964
+ // Handle subnormal values (exp5 = 0)
965
+ // - exp < 0x6e: mantissa = 0x0000 (0)
966
+ // - exp == 0x6e: mantissa = 0x0000 (0),
967
+ // 0x0020 (1/4)
968
+ // - exp == 0x6f: mantissa = 0x0020 (1/4),
969
+ // 0x0040 (1/2)
970
+ // - exp == 0x70: mantissa = 0x0040 (1/2),
971
+ // 0x0060 (3/4),
972
+ // 0x0080 (1)
973
+ man = b.select (b.icmp_ult (exp, b.i16_val (0x6e )), b.i16_val (0 ), man);
974
+ man = b.select (
975
+ b.icmp_eq (exp, b.i16_val (0x6e )),
976
+ b.select (b.icmp_ne (man, b.i16_val (0 )), b.i16_val (0x0020 ), b.i16_val (0 )),
977
+ man);
978
+ man = b.select (b.icmp_eq (exp, b.i16_val (0x6f )),
979
+ b.select (b.icmp_uge (man, b.i16_val (0x0040 )),
980
+ b.i16_val (0x0040 ), b.i16_val (0x0020 )),
981
+ man);
982
+ man = b.select (b.icmp_eq (exp, b.i16_val (0x70 )),
983
+ b.select (b.icmp_ugt (man, b.i16_val (0x0020 )),
984
+ b.select (b.icmp_uge (man, b.i16_val (0x0060 )),
985
+ b.i16_val (0x0080 ), b.i16_val (0x0060 )),
986
+ b.i16_val (0x0040 )),
987
+ man);
988
+
989
+ // Round 7-bit mantissa to 2-bit nearest, ties to even
990
+ Value sig = b.or_ (i16_ty, b.shl (i16_ty, exp5, b.i16_val (7 )), man);
991
+ Value bias = b.add (
992
+ i16_ty,
993
+ b.lshr (i16_ty, b.and_ (i16_ty, sig, b.i16_val (0x0020 )), b.i16_val (5 )),
994
+ b.i16_val (0x000F ));
995
+ i16 = b.add (i16_ty, sig, bias);
996
+
997
+ // Handle overflow using saturation mode, by setting sig to be the max.
998
+ // Overflow will happen for the following cases:
999
+ // - Any number equal or larger than 0x0F70 after rounding
1000
+ // - Exponent larger than 0x8E (including infinity 0xFF)
1001
+ i16 = b.select (b.or_ (b.icmp_ugt (exp, b.i16_val (0x8E )),
1002
+ b.icmp_uge (sig, b.i16_val (0x0F70 ))),
1003
+ b.i16_val (0x0F7F ), i16 );
1004
+
1005
+ // Handle NaN value by keeping it NaN
1006
+ i16 = b.select (
1007
+ b.and_ (b.icmp_eq (exp, b.i16_val (0xFF )), b.icmp_ne (man, b.i16_val (0x0 ))),
1008
+ b.i16_val (0x0FC0 ), i16 );
1009
+
1010
+ // Add sign bit
1011
+ i16 = b.or_ (i16_ty, b.lshr (i16_ty, s, b.i16_val (3 )), i16 );
1012
+
1013
+ // Truncate to 8-bit
1014
+ result[i] = b.trunc (i8_ty, b.lshr (i16_ty, i16 , b.i16_val (5 )));
1015
+ }
1016
+
1017
+ return result;
936
1018
}
937
1019
938
1020
static SmallVector<Value>
@@ -1262,7 +1344,8 @@ struct FpToFpOpConversion
1262
1344
{{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
1263
1345
Fp32_to_Fp8E5M2FNUZ},
1264
1346
{{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE}, Fp32_to_Fp8E4M3FN},
1265
- {{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, Fp32_to_Fp8E5M2_RTNE},
1347
+ {{F32TyID, F8E5M2TyID, RoundingMode::RTNE},
1348
+ Fp32_to_Fp8E5M2_RTNE (isaFamily)},
1266
1349
{{F32TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp32_to_Fp8E5M2_RTZ},
1267
1350
{{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32},
1268
1351
{{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32},
@@ -1318,16 +1401,23 @@ struct FpToFpOpConversion
1318
1401
numElements = 4 ;
1319
1402
}
1320
1403
1321
- // f32->fp8/bf8 with rtne, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8
1322
- // on CDNA4, is done in two steps: f32->fp16 with rtne and fp16->fp8/bf8
1323
- // with rtne
1404
+ // fp32 -> fp8 with rtne can be done in two steps:
1405
+ // - fp32 -> fp16 with rtne and
1406
+ // - fp16 -> fp8 with rtne
1407
+ // with the following exceptions:
1408
+ // 1. fp32 -> ocp fp8/bf8 on CDNA4: has hardware support
1409
+ // 2. fp32 -> nanoo fp8/bf8 on CDNA3: has hardware support
1410
+ // 3. fp32 -> ocp bf8 on non-CDNA4: has software support
1324
1411
bool useFP16IntermediateSrc =
1325
1412
srcElementType.isF32 () && !dstElementType.isF16 () &&
1326
1413
roundingMode == RoundingMode::RTNE &&
1327
1414
!(isaFamily == AMD::ISAFamily::CDNA4 &&
1328
1415
(llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType))) &&
1329
1416
!(isaFamily == AMD::ISAFamily::CDNA3 &&
1330
- (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
1417
+ (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
1418
+ dstElementType))) &&
1419
+ !(isaFamily != AMD::ISAFamily::CDNA4 &&
1420
+ (llvm::isa<Float8E5M2Type>(dstElementType)));
1331
1421
1332
1422
// fp8/bf8->f32, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8 on CDNA4,
1333
1423
// is done in two steps: fp8/bf8->fp16 and fp16->fp32
0 commit comments