@@ -372,9 +372,9 @@ static SmallVector<Value> Fp32_to_Fp8E4M3FN(Location loc,
372372}
373373
374374// Convert Fp32 to OCP Bf8 on CDNA4
375- static SmallVector<Value> Fp32_to_Fp8E5M2 (Location loc,
376- ConversionPatternRewriter &rewriter,
377- const SmallVector<Value> &v) {
375+ static SmallVector<Value>
376+ Fp32_to_Fp8E5M2_RTNE (Location loc, ConversionPatternRewriter &rewriter,
377+ const SmallVector<Value> &v) {
378378 assert (v.size () == 2 );
379379 return cvtScalePkDowncastToFp8<ROCDL::CvtScaleF32PkBf8F32Op>(loc, rewriter,
380380 v[0 ], v[1 ]);
@@ -575,6 +575,43 @@ ConverterT Fp8E5M2_to_Fp16(AMD::ISAFamily isaFamily) {
575575 : Fp8E5M2_to_Fp16_SW;
576576}
577577
578+ static SmallVector<Value>
579+ convertFp32ToFp16RTZ (Location loc, ConversionPatternRewriter &rewriter,
580+ const SmallVector<Value> &v) {
581+ assert (v.size () == 2 );
582+
583+ auto b = TritonLLVMOpBuilder (loc, rewriter);
584+ Type v2f16Ty = vec_ty (f16_ty, 2 );
585+
586+ Value result;
587+ result = rewriter.create <ROCDL::CvtPkRtz>(loc, v2f16Ty, v[0 ], v[1 ]);
588+ SmallVector<Value> ret (2 );
589+ auto idx0 = b.i32_val (0 );
590+ auto idx1 = b.i32_val (1 );
591+ ret[0 ] = b.extract_element (f16_ty, result, idx0);
592+ ret[1 ] = b.extract_element (f16_ty, result, idx1);
593+ return ret;
594+ }
595+
596+ static SmallVector<Value>
597+ Fp32_to_Fp8E5M2_RTZ (Location loc, ConversionPatternRewriter &rewriter,
598+ const SmallVector<Value> &v) {
599+ assert (v.size () == 4 );
600+ SmallVector<Value> inVals (2 );
601+ inVals[0 ] = v[0 ];
602+ inVals[1 ] = v[1 ];
603+ auto f16Vec = convertFp32ToFp16RTZ (loc, rewriter, inVals);
604+ SmallVector<Value> vec (4 );
605+ vec[0 ] = f16Vec[0 ];
606+ vec[1 ] = f16Vec[1 ];
607+ inVals[0 ] = v[2 ];
608+ inVals[1 ] = v[3 ];
609+ f16Vec = convertFp32ToFp16RTZ (loc, rewriter, inVals);
610+ vec[2 ] = f16Vec[0 ];
611+ vec[3 ] = f16Vec[1 ];
612+ return Fp16_to_Fp8E5M2_RTZ (loc, rewriter, vec);
613+ }
614+
578615static Value convertBf16ToFp32 (Location loc,
579616 ConversionPatternRewriter &rewriter,
580617 const Value &v) {
@@ -670,8 +707,8 @@ Fp8E5M2FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
670707 cvtPkF8ToFp32<ROCDL::CvtPkF32Bf8Op>(loc, rewriter, v[0 ], v[1 ]);
671708
672709 // Convert fp32 to fp16
673- ret[0 ] = LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, ret[0 ], RoundingMode::RTNE );
674- ret[1 ] = LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, ret[1 ], RoundingMode::RTNE );
710+ ret[0 ] = LLVM::AMD::cvtFp32ToFp16RTNE (loc, rewriter, ret[0 ]);
711+ ret[1 ] = LLVM::AMD::cvtFp32ToFp16RTNE (loc, rewriter, ret[1 ]);
675712
676713 return ret;
677714}
@@ -1006,8 +1043,8 @@ Fp8E4M3FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
10061043 cvtPkF8ToFp32<ROCDL::CvtPkF32Fp8Op>(loc, rewriter, v[0 ], v[1 ]);
10071044
10081045 // Convert fp32 to fp16
1009- ret[0 ] = LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, ret[0 ], RoundingMode::RTNE );
1010- ret[1 ] = LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, ret[1 ], RoundingMode::RTNE );
1046+ ret[0 ] = LLVM::AMD::cvtFp32ToFp16RTNE (loc, rewriter, ret[0 ]);
1047+ ret[1 ] = LLVM::AMD::cvtFp32ToFp16RTNE (loc, rewriter, ret[1 ]);
10111048
10121049 return ret;
10131050}
@@ -1171,11 +1208,14 @@ struct FpToFpOpConversion
11711208 {{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
11721209 Fp32_to_Fp8E5M2FNUZ},
11731210 {{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE}, Fp32_to_Fp8E4M3FN},
1174- {{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, Fp32_to_Fp8E5M2},
1211+ {{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, Fp32_to_Fp8E5M2_RTNE},
1212+ {{F32TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp32_to_Fp8E5M2_RTZ},
11751213 {{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32},
11761214 {{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32},
11771215 {{F8E4M3FNTyID, F32TyID, undefRounding}, Fp8E4M3FN_to_Fp32},
11781216 {{F8E5M2TyID, F32TyID, undefRounding}, Fp8E5M2_to_Fp32},
1217+ // F32 -> F16 with RTZ
1218+ {{F32TyID, F16TyID, RoundingMode::RTZ}, convertFp32ToFp16RTZ},
11791219 };
11801220 std::tuple<TypeID, TypeID, RoundingMode> key = {
11811221 srcTy.getTypeID (), dstTy.getTypeID (),
@@ -1195,14 +1235,14 @@ struct FpToFpOpConversion
11951235 auto dstElementType = getElementType (op.getResult ());
11961236
11971237 auto roundingMode = op.getRounding ();
1198- if (srcElementType.isF32 () && dstElementType.isF16 ()) {
1238+ if (srcElementType.isF32 () && dstElementType.isF16 () &&
1239+ roundingMode.value () == RoundingMode::RTNE) {
11991240 assert (roundingMode.has_value () &&
12001241 " rounding mode must be specified for fp32->fp16 conversion" );
12011242 SmallVector<Value> outVals;
12021243 outVals.reserve (operands[0 ].size ());
12031244 for (Value v : operands[0 ]) {
1204- outVals.push_back (
1205- LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, v, roundingMode.value ()));
1245+ outVals.push_back (LLVM::AMD::cvtFp32ToFp16RTNE (loc, rewriter, v));
12061246 }
12071247 return outVals;
12081248 }
@@ -1234,18 +1274,19 @@ struct FpToFpOpConversion
12341274 numElements = 4 ;
12351275 }
12361276
1237- // f32->fp8/bf8, if not nanoo fp8/bf8 on CDNA3 or ocp fp8/bf8 on CDNA4, is
1238- // done in two steps: f32->fp16 with rtne and fp16->fp8/bf8 with rtne
1277+ // f32->fp8/bf8 with rtne, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8
1278+ // on CDNA4, is done in two steps: f32->fp16 with rtne and fp16->fp8/bf8
1279+ // with rtne
12391280 bool useFP16IntermediateSrc =
1240- srcElementType.isF32 () &&
1281+ srcElementType.isF32 () && !dstElementType.isF16 () &&
1282+ roundingMode == RoundingMode::RTNE &&
12411283 !(isaFamily == AMD::ISAFamily::CDNA4 &&
1242- (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType)) &&
1243- roundingMode == RoundingMode::RTNE) &&
1284+ (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType))) &&
12441285 !(isaFamily == AMD::ISAFamily::CDNA3 &&
12451286 (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
12461287
1247- // fp8/bf8->f32, if not nanoo fp8/bf8 on CDNA3 or ocp fp8/bf8 on CDNA4, is
1248- // done in two steps: fp8/bf8->fp16 and fp16->fp32
1288+ // fp8/bf8->f32, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8 on CDNA4,
1289+ // is done in two steps: fp8/bf8->fp16 and fp16->fp32
12491290 bool isDstFP32 = dstElementType.isF32 ();
12501291 bool useFP16IntermediateDst =
12511292 (isDstFP32 &&
@@ -1277,8 +1318,8 @@ struct FpToFpOpConversion
12771318 }
12781319 if (useFP16IntermediateSrc)
12791320 for (Value &v : inVals)
1280- v = LLVM::AMD::cvtFp32ToFp16 (loc, rewriter, v,
1281- roundingMode. value_or (RoundingMode::RTNE));
1321+ v = LLVM::AMD::cvtFp32ToFp16RTNE (loc, rewriter, v);
1322+
12821323 inVals.resize (numElements, b.undef (typeConverter->convertType (srcType)));
12831324 SmallVector<Value> outVals;
12841325 if (srcType != dstType) {
0 commit comments