@@ -31,8 +31,10 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
3131 auto b = TritonLLVMOpBuilder (loc, rewriter);
3232 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
3333 Value fp8x4Vec = b.undef (fp8x4VecTy);
34- fp8x4Vec = b.insert_element (fp8x4VecTy, fp8x4Vec, v0, b.i32_val (0 ));
35- fp8x4Vec = b.insert_element (fp8x4VecTy, fp8x4Vec, v1, b.i32_val (1 ));
34+ auto idx0 = b.i32_val (0 );
35+ auto idx1 = b.i32_val (1 );
36+ fp8x4Vec = b.insert_element (fp8x4VecTy, fp8x4Vec, v0, idx0);
37+ fp8x4Vec = b.insert_element (fp8x4VecTy, fp8x4Vec, v1, idx1);
3638 auto i32v = b.bitcast (fp8x4Vec, i32_ty);
3739
3840 auto resType = i32_ty;
@@ -41,18 +43,22 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
4143 std::is_same_v<convertOp, ROCDL::CvtScaleF32PkF32Bf8Op>) {
4244 resType = i64_ty;
4345 dstType = f32_ty;
44- } else {
46+ } else if constexpr (std::is_same_v<convertOp,
47+ ROCDL::CvtScaleF32PkF16Fp8Op> ||
48+ std::is_same_v<convertOp,
49+ ROCDL::CvtScaleF32PkF16Bf8Op>) {
4550 resType = i32_ty;
4651 dstType = f16_ty;
52+ } else {
53+ resType = i32_ty;
54+ dstType = bf16_ty;
4755 }
4856 Value scale = b.f32_val (1 );
4957 Value select = b.false_val ();
5058 auto result = rewriter.create <convertOp>(loc, resType, i32v, scale, select);
5159 auto retVecTy = vec_ty (dstType, 2 );
5260 auto retVec = b.bitcast (result, retVecTy);
5361 SmallVector<Value> ret (2 );
54- auto idx0 = b.i32_val (0 );
55- auto idx1 = b.i32_val (1 );
5662 ret[0 ] = b.extract_element (dstType, retVec, idx0);
5763 ret[1 ] = b.extract_element (dstType, retVec, idx1);
5864 return ret;
@@ -77,8 +83,10 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
7783 } else {
7884 Type v2F16Ty = vec_ty (v0.getType (), 2 );
7985 Value srcVec = b.undef (v2F16Ty);
80- srcVec = b.insert_element (v2F16Ty, srcVec, v0, b.i32_val (0 ));
81- srcVec = b.insert_element (v2F16Ty, srcVec, v1, b.i32_val (1 ));
86+ auto idx0 = b.i32_val (0 );
87+ auto idx1 = b.i32_val (1 );
88+ srcVec = b.insert_element (v2F16Ty, srcVec, v0, idx0);
89+ srcVec = b.insert_element (v2F16Ty, srcVec, v1, idx1);
8290 result = rewriter.create <convertOp>(loc, v2I16Ty, v2I16Vec, srcVec, scale,
8391 select);
8492 }
@@ -698,9 +706,10 @@ ConverterT Fp8E5M2FNUZ_to_Fp16(AMD::ISAFamily isaFamily) {
698706 : Fp8E5M2FNUZ_to_Fp16_SW;
699707}
700708
701- static SmallVector<Value> Fp8E5M2_to_Bf16 (Location loc,
702- ConversionPatternRewriter &rewriter,
703- const SmallVector<Value> &v) {
709+ // OCP Bf8 -> Bf16
710+ static SmallVector<Value>
711+ Fp8E5M2_to_Bf16_SW (Location loc, ConversionPatternRewriter &rewriter,
712+ const SmallVector<Value> &v) {
704713 auto b = TritonLLVMOpBuilder (loc, rewriter);
705714 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
706715 Value a0 = b.undef (fp8x4VecTy);
@@ -761,6 +770,19 @@ static SmallVector<Value> Fp8E5M2_to_Bf16(Location loc,
761770 b.extract_element (bf16_ty, out1, b.i32_val (1 ))};
762771}
763772
773+ static SmallVector<Value>
774+ Fp8E5M2_to_Bf16_HW (Location loc, ConversionPatternRewriter &rewriter,
775+ const SmallVector<Value> &v) {
776+ assert (v.size () == 2 );
777+ return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkBf16Bf8Op>(loc, rewriter,
778+ v[0 ], v[1 ]);
779+ }
780+
781+ ConverterT Fp8E5M2_to_Bf16 (AMD::ISAFamily isaFamily) {
782+ return isaFamily == AMD::ISAFamily::CDNA4 ? Fp8E5M2_to_Bf16_HW
783+ : Fp8E5M2_to_Bf16_SW;
784+ }
785+
764786// Bf16 -> OCP Bf8
765787static SmallVector<Value>
766788Bf16_to_Fp8E5M2_SW (Location loc, ConversionPatternRewriter &rewriter,
@@ -869,9 +891,9 @@ static SmallVector<Value> Bf16_to_Fp8E4M3FN(Location loc,
869891}
870892
871893// fp8e4m3fn to bf16
872- static SmallVector<Value> Fp8E4M3FN_to_Bf16 (Location loc,
873- ConversionPatternRewriter &rewriter,
874- const SmallVector<Value> &v) {
894+ static SmallVector<Value>
895+ Fp8E4M3FN_to_Bf16_SW (Location loc, ConversionPatternRewriter &rewriter,
896+ const SmallVector<Value> &v) {
875897 auto b = TritonLLVMOpBuilder (loc, rewriter);
876898 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
877899 Value a0 = b.undef (fp8x4VecTy);
@@ -904,6 +926,19 @@ static SmallVector<Value> Fp8E4M3FN_to_Bf16(Location loc,
904926 b.extract_element (bf16_ty, out0, b.i32_val (1 ))};
905927}
906928
929+ static SmallVector<Value>
930+ Fp8E4M3FN_to_Bf16_HW (Location loc, ConversionPatternRewriter &rewriter,
931+ const SmallVector<Value> &v) {
932+ assert (v.size () == 2 );
933+ return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkBf16Fp8Op>(loc, rewriter,
934+ v[0 ], v[1 ]);
935+ }
936+
937+ ConverterT Fp8E4M3FN_to_Bf16 (AMD::ISAFamily isaFamily) {
938+ return isaFamily == AMD::ISAFamily::CDNA4 ? Fp8E4M3FN_to_Bf16_HW
939+ : Fp8E4M3FN_to_Bf16_SW;
940+ }
941+
907942// fp8e4m3fnuz to bf16
908943static SmallVector<Value>
909944Fp8E4M3FNUZ_to_Bf16 (Location loc, ConversionPatternRewriter &rewriter,
@@ -1130,9 +1165,10 @@ struct FpToFpOpConversion
11301165 Fp16_to_Fp8E5M2_RTNE (isaFamily)},
11311166 {{F16TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp16_to_Fp8E5M2_RTZ},
11321167 // F8 -> BF16
1133- {{F8E5M2TyID, BF16TyID, undefRounding}, Fp8E5M2_to_Bf16},
1168+ {{F8E5M2TyID, BF16TyID, undefRounding}, Fp8E5M2_to_Bf16 (isaFamily) },
11341169 {{F8E5M2FNUZTyID, BF16TyID, undefRounding}, Fp8E5M2FNUZ_to_Bf16},
1135- {{F8E4M3FNTyID, BF16TyID, undefRounding}, Fp8E4M3FN_to_Bf16},
1170+ {{F8E4M3FNTyID, BF16TyID, undefRounding},
1171+ Fp8E4M3FN_to_Bf16 (isaFamily)},
11361172 {{F8E4M3FNUZTyID, BF16TyID, undefRounding}, Fp8E4M3FNUZ_to_Bf16},
11371173 // BF16 -> F8
11381174 {{BF16TyID, F8E5M2TyID, RoundingMode::RTNE},
@@ -1197,16 +1233,11 @@ struct FpToFpOpConversion
11971233 }
11981234
11991235 // numElements = 4 for conversions:
1200- // ocp bf8->bf16, or
1201- // ocp bf8->fp32/fp16 on non-CDNA4, or
1236+ // ocp bf8->fp32/fp16/bf16 on non-CDNA4, or
12021237 // fp32/bf16/fp16->ocp bf8 on non-CDNA4
12031238 // fp32/bf16/fp16->ocp bf8 (RTZ) on CDNA4
12041239 size_t numElements = 2 ;
12051240 if ((llvm::isa<Float8E5M2Type>(srcElementType) &&
1206- llvm::isa<BFloat16Type>(dstElementType)) ||
1207- (llvm::isa<Float8E5M2Type>(srcElementType) &&
1208- (llvm::isa<Float16Type>(dstElementType) ||
1209- llvm::isa<Float32Type>(dstElementType)) &&
12101241 isaFamily != AMD::ISAFamily::CDNA4) ||
12111242 (llvm::isa<Float8E5M2Type>(dstElementType) &&
12121243 isaFamily != AMD::ISAFamily::CDNA4) ||
0 commit comments