Skip to content

Commit b56a6ba

Browse files
!fixup custom hooks for trunc/ext, add missing instruction selection patterns
1 parent 22f3976 commit b56a6ba

File tree

5 files changed

+102
-7
lines changed

5 files changed

+102
-7
lines changed

llvm/lib/Target/X86/X86FastISel.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ class X86FastISel final : public FastISel {
147147
/// computed in an SSE register, not on the X87 floating point stack.
148148
bool isScalarFPTypeInSSEReg(EVT VT) const {
149149
return (VT == MVT::f64 && Subtarget->hasSSE2()) ||
150-
(VT == MVT::f32 && Subtarget->hasSSE1()) || VT == MVT::f16;
150+
(VT == MVT::f32 && Subtarget->hasSSE1()) || VT == MVT::f16 ||
151+
VT == MVT::bf16;
151152
}
152153

153154
bool isTypeLegal(Type *Ty, MVT &VT, bool AllowI1 = false);
@@ -2283,6 +2284,7 @@ bool X86FastISel::X86FastEmitPseudoSelect(MVT RetVT, const Instruction *I) {
22832284
case MVT::i16: Opc = X86::CMOV_GR16; break;
22842285
case MVT::i32: Opc = X86::CMOV_GR32; break;
22852286
case MVT::f16:
2287+
case MVT::bf16:
22862288
Opc = Subtarget->hasAVX512() ? X86::CMOV_FR16X : X86::CMOV_FR16; break;
22872289
case MVT::f32:
22882290
Opc = Subtarget->hasAVX512() ? X86::CMOV_FR32X : X86::CMOV_FR32; break;
@@ -3972,6 +3974,7 @@ Register X86FastISel::fastMaterializeFloatZero(const ConstantFP *CF) {
39723974
switch (VT.SimpleTy) {
39733975
default: return 0;
39743976
case MVT::f16:
3977+
case MVT::bf16:
39753978
Opc = HasAVX512 ? X86::AVX512_FsFLD0SH : X86::FsFLD0SH;
39763979
break;
39773980
case MVT::f32:

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,9 +678,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
678678
// non-optsize case.
679679
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
680680

681-
// Set the operation action Custom for bitcast to do the customization
682-
// later.
681+
// Set the operation action Custom for bitcast and conversion, and fall back
682+
// to software libcalls for the latter for now.
683683
setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
684+
setOperationAction(ISD::FP_EXTEND, MVT::bf16, Custom);
685+
setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
684686

685687
for (auto VT : { MVT::f32, MVT::f64 }) {
686688
// Use ANDPD to simulate FABS.
@@ -22066,6 +22068,31 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
2206622068
return Res;
2206722069
}
2206822070

22071+
if (SVT == MVT::bf16 && VT == MVT::f32) {
22072+
TargetLowering::CallLoweringInfo CLI(DAG);
22073+
Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
22074+
22075+
TargetLowering::ArgListTy Args;
22076+
TargetLowering::ArgListEntry Entry;
22077+
Entry.Node = In;
22078+
Entry.Ty = EVT(SVT).getTypeForEVT(*DAG.getContext());
22079+
Args.push_back(Entry);
22080+
22081+
SDValue Callee =
22082+
DAG.getExternalSymbol(getLibcallName(RTLIB::FPEXT_BF16_F32),
22083+
getPointerTy(DAG.getDataLayout()));
22084+
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
22085+
CallingConv::C, EVT(VT).getTypeForEVT(*DAG.getContext()), Callee,
22086+
std::move(Args));
22087+
22088+
SDValue Res;
22089+
std::tie(Res, Chain) = LowerCallTo(CLI);
22090+
if (IsStrict)
22091+
Res = DAG.getMergeValues({Res, Chain}, DL);
22092+
22093+
return Res;
22094+
}
22095+
2206922096
if (!SVT.isVector() || SVT.getVectorElementType() == MVT::bf16)
2207022097
return Op;
2207122098

@@ -22149,6 +22176,30 @@ SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
2214922176
((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
2215022177
Subtarget.hasAVXNECONVERT()))
2215122178
return Op;
22179+
22180+
// Need a soft libcall if the target does not have BF16.
22181+
if (SVT.getScalarType() == MVT::f32 || SVT.getScalarType() == MVT::f64) {
22182+
TargetLowering::CallLoweringInfo CLI(DAG);
22183+
Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
22184+
22185+
TargetLowering::ArgListTy Args;
22186+
TargetLowering::ArgListEntry Entry;
22187+
Entry.Node = In;
22188+
Entry.Ty = EVT(SVT).getTypeForEVT(*DAG.getContext());
22189+
Args.push_back(Entry);
22190+
SDValue Callee = DAG.getExternalSymbol(
22191+
getLibcallName(SVT == MVT::f64 ? RTLIB::FPROUND_F64_BF16
22192+
: RTLIB::FPROUND_F32_BF16),
22193+
getPointerTy(DAG.getDataLayout()));
22194+
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
22195+
CallingConv::C, EVT(MVT::bf16).getTypeForEVT(*DAG.getContext()),
22196+
Callee, std::move(Args));
22197+
22198+
SDValue Res;
22199+
std::tie(Res, Chain) = LowerCallTo(CLI);
22200+
return IsStrict ? DAG.getMergeValues({Res, Chain}, DL) : Res;
22201+
}
22202+
2215222203
return SDValue();
2215322204
}
2215422205

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11585,6 +11585,13 @@ let Predicates = [HasBWI], AddedComplexity = -10 in {
1158511585
def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
1158611586
def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
1158711587

11588+
let Predicates = [HasBWI, HasBF16] in {
11589+
def : Pat<(bf16 (load addr:$src)), (COPY_TO_REGCLASS (VPINSRWZrmi (v8i16 (IMPLICIT_DEF)), addr:$src, 0), FR16X)>;
11590+
def : Pat<(store bf16:$src, addr:$dst), (VPEXTRWZmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0)>;
11591+
def : Pat<(i16 (bitconvert bf16:$src)), (EXTRACT_SUBREG (VPEXTRWZrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128X)), 0), sub_16bit)>;
11592+
def : Pat<(bf16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (VPINSRWZrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16X)>;
11593+
}
11594+
1158811595
//===----------------------------------------------------------------------===//
1158911596
// VSHUFPS - VSHUFPD Operations
1159011597
//===----------------------------------------------------------------------===//
@@ -12809,17 +12816,31 @@ let Predicates = [HasBF16, HasVLX] in {
1280912816
}
1281012817

1281112818
let Predicates = [HasBF16] in {
12819+
def : Pat<(v8bf16 (X86VBroadcastld16 addr:$src)),
12820+
(VPBROADCASTWrm addr:$src)>;
12821+
def : Pat<(v16bf16 (X86VBroadcastld16 addr:$src)),
12822+
(VPBROADCASTWYrm addr:$src)>;
1281212823
def : Pat<(v32bf16 (X86VBroadcastld16 addr:$src)),
1281312824
(VPBROADCASTWZrm addr:$src)>;
1281412825

12826+
def : Pat<(v8bf16 (X86VBroadcast (v8bf16 VR128:$src))),
12827+
(VPBROADCASTWrr VR128:$src)>;
12828+
def : Pat<(v16bf16 (X86VBroadcast (v8bf16 VR128:$src))),
12829+
(VPBROADCASTWYrr VR128:$src)>;
1281512830
def : Pat<(v32bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
1281612831
(VPBROADCASTWZrr VR128X:$src)>;
1281712832

1281812833
def : Pat<(v16bf16 (X86vfpround (v16f32 VR512:$src))),
1281912834
(VCVTNEPS2BF16Zrr VR512:$src)>;
1282012835
def : Pat<(v16bf16 (X86vfpround (loadv16f32 addr:$src))),
1282112836
(VCVTNEPS2BF16Zrm addr:$src)>;
12822-
// TODO: No scalar broadcast due to we don't support legal scalar bf16 so far.
12837+
12838+
def : Pat<(v8bf16 (X86VBroadcast (bf16 FR16X:$src))),
12839+
(VPBROADCASTWrr (COPY_TO_REGCLASS FR16X:$src, VR128))>;
12840+
def : Pat<(v16bf16 (X86VBroadcast (bf16 FR16X:$src))),
12841+
(VPBROADCASTWYrr (COPY_TO_REGCLASS FR16X:$src, VR128))>;
12842+
def : Pat<(v32bf16 (X86VBroadcast (bf16 FR16X:$src))),
12843+
(VPBROADCASTWZrr (COPY_TO_REGCLASS FR16X:$src, VR128X))>;
1282312844
}
1282412845

1282512846
let Constraints = "$src1 = $dst" in {

llvm/lib/Target/X86/X86InstrFragmentsSIMD.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,10 @@ def fp16imm0 : PatLeaf<(f16 fpimm), [{
11931193
return N->isExactlyValue(+0.0);
11941194
}]>;
11951195

1196+
def bfp16imm0 : PatLeaf<(bf16 fpimm), [{
1197+
return N->isExactlyValue(+0.0);
1198+
}]>;
1199+
11961200
def fp32imm0 : PatLeaf<(f32 fpimm), [{
11971201
return N->isExactlyValue(+0.0);
11981202
}]>;

llvm/lib/Target/X86/X86InstrSSE.td

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4048,6 +4048,19 @@ let Predicates = [HasAVX, NoBWI] in {
40484048
def : Pat<(f16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (VPINSRWrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16)>;
40494049
}
40504050

4051+
let Predicates = [UseSSE2] in {
4052+
def : Pat<(bf16 (load addr:$src)), (COPY_TO_REGCLASS (PINSRWrmi (v8i16 (IMPLICIT_DEF)), addr:$src, 0), FR16X)>;
4053+
def : Pat<(store bf16:$src, addr:$dst), (MOV16mr addr:$dst, (EXTRACT_SUBREG (PEXTRWrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0), sub_16bit))>;
4054+
def : Pat<(i16 (bitconvert bf16:$src)), (EXTRACT_SUBREG (PEXTRWrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0), sub_16bit)>;
4055+
def : Pat<(bf16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (PINSRWrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16X)>;
4056+
}
4057+
4058+
let Predicates = [HasAVX, NoBWI] in {
4059+
def : Pat<(bf16 (load addr:$src)), (COPY_TO_REGCLASS (VPINSRWrmi (v8i16 (IMPLICIT_DEF)), addr:$src, 0), FR16X)>;
4060+
def : Pat<(i16 (bitconvert bf16:$src)), (EXTRACT_SUBREG (VPEXTRWrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0), sub_16bit)>;
4061+
def : Pat<(bf16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (VPINSRWrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16)>;
4062+
}
4063+
40514064
//===---------------------------------------------------------------------===//
40524065
// SSE2 - Packed Mask Creation
40534066
//===---------------------------------------------------------------------===//
@@ -5279,12 +5292,15 @@ let Predicates = [HasAVX, NoBWI] in
52795292

52805293
defm PEXTRW : SS41I_extract16<0x15, "pextrw">;
52815294

5282-
let Predicates = [UseSSE41] in
5295+
let Predicates = [UseSSE41] in {
52835296
def : Pat<(store f16:$src, addr:$dst), (PEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16:$src, VR128)), 0)>;
5297+
def : Pat<(store bf16:$src, addr:$dst), (PEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0)>;
5298+
}
52845299

5285-
let Predicates = [HasAVX, NoBWI] in
5300+
let Predicates = [HasAVX, NoBWI] in {
52865301
def : Pat<(store f16:$src, addr:$dst), (VPEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16:$src, VR128)), 0)>;
5287-
5302+
def : Pat<(store bf16:$src, addr:$dst), (VPEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0)>;
5303+
}
52885304

52895305
/// SS41I_extract32 - SSE 4.1 extract 32 bits to int reg or memory destination
52905306
multiclass SS41I_extract32<bits<8> opc, string OpcodeStr> {

0 commit comments

Comments
 (0)