diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index a583a5cb990e7..21f8c7cfeec1f 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -911,8 +911,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::MUL, MVT::i1, Promote);
 
   if (Subtarget->hasBF16ConversionInsts()) {
-    setOperationAction(ISD::FP_ROUND, MVT::v2bf16, Legal);
-    setOperationAction(ISD::FP_ROUND, MVT::bf16, Legal);
+    setOperationAction(ISD::FP_ROUND, {MVT::bf16, MVT::v2bf16}, Custom);
     setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Legal);
   }
 
@@ -6888,23 +6887,34 @@ SDValue SITargetLowering::getFPExtOrFPRound(SelectionDAG &DAG, SDValue Op,
 }
 
 SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
-  assert(Op.getValueType() == MVT::f16 &&
-         "Do not know how to custom lower FP_ROUND for non-f16 type");
-
   SDValue Src = Op.getOperand(0);
   EVT SrcVT = Src.getValueType();
-  if (SrcVT != MVT::f64)
-    return Op;
-
-  // TODO: Handle strictfp
-  if (Op.getOpcode() != ISD::FP_ROUND)
+  if (SrcVT.getScalarType() != MVT::f64)
     return Op;
 
+  EVT DstVT = Op.getValueType();
   SDLoc DL(Op);
+  if (DstVT == MVT::f16) {
+    // TODO: Handle strictfp
+    if (Op.getOpcode() != ISD::FP_ROUND)
+      return Op;
+
+    SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
+    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
+    return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
+  }
+
+  assert(DstVT.getScalarType() == MVT::bf16 &&
+         "custom lower FP_ROUND for f16 or bf16");
+  assert(Subtarget->hasBF16ConversionInsts() && "f32 -> bf16 is legal");
 
-  SDValue FpToFp16 = DAG.getNode(ISD::FP_TO_FP16, DL, MVT::i32, Src);
-  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToFp16);
-  return DAG.getNode(ISD::BITCAST, DL, MVT::f16, Trunc);
+  // Round-inexact-to-odd f64 to f32, then do the final rounding using the
+  // hardware f32 -> bf16 instruction.
+  EVT F32VT = SrcVT.isVector() ? SrcVT.changeVectorElementType(MVT::f32)
+                               : MVT::f32;
+  SDValue Rod = expandRoundInexactToOdd(F32VT, Src, DL, DAG);
+  return DAG.getNode(ISD::FP_ROUND, DL, DstVT, Rod,
+                     DAG.getTargetConstant(0, DL, MVT::i32));
 }
 
 SDValue SITargetLowering::lowerFMINNUM_FMAXNUM(SDValue Op,
diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
index 9feb5df2f9203..8686a85620a17 100644
--- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
@@ -1443,16 +1443,11 @@ let SubtargetPredicate = HasBF16ConversionInsts in {
 }
 def : GCNPat<(v2bf16 (bf16_fpround v2f32:$src)),
              (V_CVT_PK_BF16_F32_e64 0, (EXTRACT_SUBREG VReg_64:$src, sub0), 0, (EXTRACT_SUBREG VReg_64:$src, sub1))>;
-def : GCNPat<(v2bf16 (bf16_fpround v2f64:$src)),
-             (V_CVT_PK_BF16_F32_e64 0, (V_CVT_F32_F64_e64 0, (EXTRACT_SUBREG VReg_128:$src, sub0_sub1)),
-                                    0, (V_CVT_F32_F64_e64 0, (EXTRACT_SUBREG VReg_128:$src, sub2_sub3)))>;
 def : GCNPat<(v2bf16 (build_vector (bf16 (bf16_fpround (f32 (VOP3Mods f32:$src0, i32:$src0_modifiers)))),
                                    (bf16 (bf16_fpround (f32 (VOP3Mods f32:$src1, i32:$src1_modifiers)))))),
              (V_CVT_PK_BF16_F32_e64 $src0_modifiers, $src0, $src1_modifiers, $src1)>;
 def : GCNPat<(bf16 (bf16_fpround (f32 (VOP3Mods f32:$src0, i32:$src0_modifiers)))),
              (V_CVT_PK_BF16_F32_e64 $src0_modifiers, $src0, 0, (f32 (IMPLICIT_DEF)))>;
-def : GCNPat<(bf16 (bf16_fpround (f64 (VOP3Mods f64:$src0, i32:$src0_modifiers)))),
-             (V_CVT_PK_BF16_F32_e64 0, (f32 (V_CVT_F32_F64_e64 $src0_modifiers, $src0)), 0, (f32 (IMPLICIT_DEF)))>;
 }
 
 class Cvt_Scale_Sr_F32ToBF16F16_Pat : GCNPat<
diff --git a/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll b/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll
index 4c01e583713a7..3be911ab9e7f4 100644
--- a/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll
+++ b/llvm/test/CodeGen/AMDGPU/bf16-conversions.ll
@@ -153,9 +153,34 @@ define amdgpu_ps float @v_test_cvt_v2f64_v2bf16_v(<2 x double> %src) {
 ;
 ; GFX-950-LABEL: v_test_cvt_v2f64_v2bf16_v:
 ; GFX-950:       ; %bb.0:
-; GFX-950-NEXT:    v_cvt_f32_f64_e32 v2, v[2:3]
-; GFX-950-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
-; GFX-950-NEXT:    v_cvt_pk_bf16_f32 v0, v0, v2
+; GFX-950-NEXT:    v_mov_b32_e32 v4, v3
+; GFX-950-NEXT:    v_and_b32_e32 v3, 0x7fffffff, v4
+; GFX-950-NEXT:    v_mov_b32_e32 v5, v1
+; GFX-950-NEXT:    v_cvt_f32_f64_e32 v1, v[2:3]
+; GFX-950-NEXT:    v_cvt_f64_f32_e32 v[6:7], v1
+; GFX-950-NEXT:    v_and_b32_e32 v8, 1, v1
+; GFX-950-NEXT:    v_cmp_gt_f64_e64 s[2:3], v[2:3], v[6:7]
+; GFX-950-NEXT:    v_cmp_nlg_f64_e32 vcc, v[2:3], v[6:7]
+; GFX-950-NEXT:    v_cmp_eq_u32_e64 s[0:1], 1, v8
+; GFX-950-NEXT:    v_cndmask_b32_e64 v2, -1, 1, s[2:3]
+; GFX-950-NEXT:    v_add_u32_e32 v2, v1, v2
+; GFX-950-NEXT:    s_or_b64 vcc, vcc, s[0:1]
+; GFX-950-NEXT:    v_cndmask_b32_e32 v1, v2, v1, vcc
+; GFX-950-NEXT:    s_brev_b32 s4, 1
+; GFX-950-NEXT:    v_and_or_b32 v4, v4, s4, v1
+; GFX-950-NEXT:    v_and_b32_e32 v1, 0x7fffffff, v5
+; GFX-950-NEXT:    v_cvt_f32_f64_e32 v6, v[0:1]
+; GFX-950-NEXT:    v_cvt_f64_f32_e32 v[2:3], v6
+; GFX-950-NEXT:    v_and_b32_e32 v7, 1, v6
+; GFX-950-NEXT:    v_cmp_gt_f64_e64 s[2:3], v[0:1], v[2:3]
+; GFX-950-NEXT:    v_cmp_nlg_f64_e32 vcc, v[0:1], v[2:3]
+; GFX-950-NEXT:    v_cmp_eq_u32_e64 s[0:1], 1, v7
+; GFX-950-NEXT:    v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT:    v_add_u32_e32 v0, v6, v0
+; GFX-950-NEXT:    s_or_b64 vcc, vcc, s[0:1]
+; GFX-950-NEXT:    v_cndmask_b32_e32 v0, v0, v6, vcc
+; GFX-950-NEXT:    v_and_or_b32 v0, v5, s4, v0
+; GFX-950-NEXT:    v_cvt_pk_bf16_f32 v0, v0, v4
 ; GFX-950-NEXT:    ; return to shader part epilog
   %res = fptrunc <2 x double> %src to <2 x bfloat>
   %cast = bitcast <2 x bfloat> %res to float
@@ -347,7 +372,18 @@ define amdgpu_ps void @fptrunc_f64_to_bf16(double %a, ptr %out) {
 ;
 ; GFX-950-LABEL: fptrunc_f64_to_bf16:
 ; GFX-950:       ; %bb.0: ; %entry
-; GFX-950-NEXT:    v_cvt_f32_f64_e32 v0, v[0:1]
+; GFX-950-NEXT:    v_cvt_f32_f64_e64 v6, |v[0:1]|
+; GFX-950-NEXT:    v_cvt_f64_f32_e32 v[4:5], v6
+; GFX-950-NEXT:    v_and_b32_e32 v7, 1, v6
+; GFX-950-NEXT:    v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_eq_u32_e32 vcc, 1, v7
+; GFX-950-NEXT:    v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT:    v_add_u32_e32 v0, v6, v0
+; GFX-950-NEXT:    s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT:    v_cndmask_b32_e32 v0, v0, v6, vcc
+; GFX-950-NEXT:    s_brev_b32 s0, 1
+; GFX-950-NEXT:    v_and_or_b32 v0, v1, s0, v0
 ; GFX-950-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
 ; GFX-950-NEXT:    flat_store_short v[2:3], v0
 ; GFX-950-NEXT:    s_endpgm
@@ -385,7 +421,19 @@ define amdgpu_ps void @fptrunc_f64_to_bf16_neg(double %a, ptr %out) {
 ;
 ; GFX-950-LABEL: fptrunc_f64_to_bf16_neg:
 ; GFX-950:       ; %bb.0: ; %entry
-; GFX-950-NEXT:    v_cvt_f32_f64_e64 v0, -v[0:1]
+; GFX-950-NEXT:    v_cvt_f32_f64_e64 v7, |v[0:1]|
+; GFX-950-NEXT:    v_cvt_f64_f32_e32 v[4:5], v7
+; GFX-950-NEXT:    v_and_b32_e32 v8, 1, v7
+; GFX-950-NEXT:    v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_eq_u32_e32 vcc, 1, v8
+; GFX-950-NEXT:    v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT:    v_add_u32_e32 v0, v7, v0
+; GFX-950-NEXT:    s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT:    s_brev_b32 s4, 1
+; GFX-950-NEXT:    v_xor_b32_e32 v6, 0x80000000, v1
+; GFX-950-NEXT:    v_cndmask_b32_e32 v0, v0, v7, vcc
+; GFX-950-NEXT:    v_and_or_b32 v0, v6, s4, v0
 ; GFX-950-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
 ; GFX-950-NEXT:    flat_store_short v[2:3], v0
 ; GFX-950-NEXT:    s_endpgm
@@ -424,7 +472,19 @@ define amdgpu_ps void @fptrunc_f64_to_bf16_abs(double %a, ptr %out) {
 ;
 ; GFX-950-LABEL: fptrunc_f64_to_bf16_abs:
 ; GFX-950:       ; %bb.0: ; %entry
-; GFX-950-NEXT:    v_cvt_f32_f64_e64 v0, |v[0:1]|
+; GFX-950-NEXT:    v_cvt_f32_f64_e64 v7, |v[0:1]|
+; GFX-950-NEXT:    v_cvt_f64_f32_e32 v[4:5], v7
+; GFX-950-NEXT:    v_and_b32_e32 v8, 1, v7
+; GFX-950-NEXT:    v_cmp_gt_f64_e64 s[2:3], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_nlg_f64_e64 s[0:1], |v[0:1]|, v[4:5]
+; GFX-950-NEXT:    v_cmp_eq_u32_e32 vcc, 1, v8
+; GFX-950-NEXT:    v_cndmask_b32_e64 v0, -1, 1, s[2:3]
+; GFX-950-NEXT:    v_add_u32_e32 v0, v7, v0
+; GFX-950-NEXT:    s_or_b64 vcc, s[0:1], vcc
+; GFX-950-NEXT:    v_and_b32_e32 v6, 0x7fffffff, v1
+; GFX-950-NEXT:    v_cndmask_b32_e32 v0, v0, v7, vcc
+; GFX-950-NEXT:    s_brev_b32 s0, 1
+; GFX-950-NEXT:    v_and_or_b32 v0, v6, s0, v0
 ; GFX-950-NEXT:    v_cvt_pk_bf16_f32 v0, v0, s0
 ; GFX-950-NEXT:    flat_store_short v[2:3], v0
 ; GFX-950-NEXT:    s_endpgm
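
Note on the GFX-950 check lines above: they are the expansion produced by expandRoundInexactToOdd followed by the hardware v_cvt_pk_bf16_f32. As a reading aid, here is a minimal scalar C++ sketch of the same two-step scheme; it is not the LLVM code, the helper names are illustrative, and NaN/Inf payload details are ignored.

    #include <cmath>
    #include <cstdint>
    #include <cstring>

    // Step 1: f64 -> f32 with round-to-odd, mirroring the cvt/cmp/cndmask
    // sequence in the checks (which also works on |x| and reapplies the sign).
    static float f64_to_f32_round_odd(double x) {
      double ax = std::fabs(x);
      float f = static_cast<float>(ax);    // v_cvt_f32_f64: round-to-nearest-even
      double rt = static_cast<double>(f);  // v_cvt_f64_f32: widening is exact
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      if (ax != rt && (bits & 1) == 0) {   // inexact and significand LSB even:
        bits += (ax > rt) ? 1u : -1u;      // step one ulp toward |x|, the odd neighbor
        std::memcpy(&f, &bits, sizeof(f));
      }
      return std::copysign(f, x);          // v_and_or_b32 with the source sign bit
    }

    // Step 2: f32 -> bf16 with round-to-nearest-even, the job of v_cvt_pk_bf16_f32.
    static uint16_t f32_to_bf16_rne(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      bits += 0x7FFFu + ((bits >> 16) & 1); // round to nearest, ties to even
      return static_cast<uint16_t>(bits >> 16);
    }

Composing the two, f32_to_bf16_rne(f64_to_f32_round_odd(x)), gives the same result as a single correctly rounded f64 -> bf16 conversion: the forced-odd significand LSB of the intermediate f32 acts as a sticky bit, so the second rounding can never mistake an inexact value for a tie, which is exactly the double-rounding bug the removed v2f64/f64 patterns had.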