@@ -661,10 +661,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
661661 };
662662
663663 if (!Subtarget.useSoftFloat() && Subtarget.hasSSE2()) {
664- // f16, f32 and f64 use SSE.
664+ // f16, bf16, f32 and f64 use SSE.
665665 // Set up the FP register classes.
666666 addRegisterClass(MVT::f16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
667667 : &X86::FR16RegClass);
668+ addRegisterClass(MVT::bf16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
669+ : &X86::FR16RegClass);
668670 addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? &X86::FR32XRegClass
669671 : &X86::FR32RegClass);
670672 addRegisterClass(MVT::f64, Subtarget.hasAVX512() ? &X86::FR64XRegClass
@@ -676,6 +678,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
676678 // non-optsize case.
677679 setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
678680
681+	  // Mark BITCAST involving bf16 as Custom so LowerBITCAST can handle it
682+	  // (f16 <-> bf16 bitcasts are resolved there as no-ops).
683+ setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
684+
679685 for (auto VT : { MVT::f32, MVT::f64 }) {
680686 // Use ANDPD to simulate FABS.
681687 setOperationAction(ISD::FABS, VT, Custom);
@@ -32151,6 +32157,10 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
3215132157 return DAG.getZExtOrTrunc(V, DL, DstVT);
3215232158 }
3215332159
32160+ // Bitcasts between f16 and bf16 should be legal.
32161+ if (DstVT == MVT::f16 || DstVT == MVT::bf16)
32162+ return Op;
32163+
3215432164 assert((SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 ||
3215532165 SrcVT == MVT::i64) && "Unexpected VT!");
3215632166
0 commit comments