Skip to content

Commit 7f2a112

Browse files
[X86] Add support for __bf16 to f16 conversion
`bf16` is a 16-bit type (a typedef of `short`) introduced with AVX512_BF16, and it should be able to use the same SSE/AVX registers that are used for `f16`. Fixes: #134222.
1 parent 0fc7aec commit 7f2a112

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,10 +661,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
661661
};
662662

663663
if (!Subtarget.useSoftFloat() && Subtarget.hasSSE2()) {
664-
// f16, f32 and f64 use SSE.
664+
// f16, bf16, f32 and f64 use SSE.
665665
// Set up the FP register classes.
666666
addRegisterClass(MVT::f16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
667667
: &X86::FR16RegClass);
668+
addRegisterClass(MVT::bf16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
669+
: &X86::FR16RegClass);
668670
addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? &X86::FR32XRegClass
669671
: &X86::FR32RegClass);
670672
addRegisterClass(MVT::f64, Subtarget.hasAVX512() ? &X86::FR64XRegClass
@@ -676,6 +678,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
676678
// non-optsize case.
677679
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
678680

681+
// Set the operation action Custom for bitcast to do the customization
682+
// later.
683+
setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
684+
679685
for (auto VT : { MVT::f32, MVT::f64 }) {
680686
// Use ANDPD to simulate FABS.
681687
setOperationAction(ISD::FABS, VT, Custom);
@@ -32151,6 +32157,10 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
3215132157
return DAG.getZExtOrTrunc(V, DL, DstVT);
3215232158
}
3215332159

32160+
// Bitcasts between f16 and bf16 should be legal.
32161+
if (DstVT == MVT::f16 || DstVT == MVT::bf16)
32162+
return Op;
32163+
3215432164
assert((SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 ||
3215532165
SrcVT == MVT::i64) && "Unexpected VT!");
3215632166

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2456,6 +2456,11 @@ let Predicates = [HasFP16] in {
24562456
(VCMPSHZrmi FR16X:$src1, addr:$src2, (X86cmpm_imm_commute timm:$cc))>;
24572457
}
24582458

2459+
let Predicates = [HasAVX512, HasBF16] in {
2460+
def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
2461+
def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
2462+
}
2463+
24592464
// ----------------------------------------------------------------
24602465
// FPClass
24612466

0 commit comments

Comments
 (0)