Skip to content

Commit fa64a21

Browse files
authored
[X86][FP16] Adding lowerings for FP16 ISD::LRINT and ISD::LLRINT (llvm#127382)
Address comment in llvm#126477
1 parent 7981c29 commit fa64a21

File tree

4 files changed

+2839
-7
lines changed

4 files changed

+2839
-7
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
735735
setOperationAction(ISD::FCANONICALIZE, MVT::f16, Custom);
736736
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
737737
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
738+
setOperationAction(ISD::LRINT, MVT::f16, Expand);
739+
setOperationAction(ISD::LLRINT, MVT::f16, Expand);
738740

739741
// Lower this to MOVMSK plus an AND.
740742
setOperationAction(ISD::FGETSIGN, MVT::i64, Custom);
@@ -2309,6 +2311,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
23092311
setOperationAction(ISD::FMINIMUMNUM, MVT::f16, Custom);
23102312
setOperationAction(ISD::FP_EXTEND, MVT::f32, Legal);
23112313
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Legal);
2314+
setOperationAction(ISD::LRINT, MVT::f16, Legal);
2315+
setOperationAction(ISD::LLRINT, MVT::f16, Legal);
23122316

23132317
setCondCodeAction(ISD::SETOEQ, MVT::f16, Expand);
23142318
setCondCodeAction(ISD::SETUNE, MVT::f16, Expand);
@@ -2356,6 +2360,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
23562360
setOperationAction(ISD::FMAXIMUM, MVT::v32f16, Custom);
23572361
setOperationAction(ISD::FMINIMUMNUM, MVT::v32f16, Custom);
23582362
setOperationAction(ISD::FMAXIMUMNUM, MVT::v32f16, Custom);
2363+
setOperationAction(ISD::LRINT, MVT::v32f16, Legal);
2364+
setOperationAction(ISD::LLRINT, MVT::v8f16, Legal);
23592365
}
23602366

23612367
if (Subtarget.hasVLX()) {
@@ -2410,6 +2416,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24102416
setOperationAction(ISD::FMAXIMUM, MVT::v16f16, Custom);
24112417
setOperationAction(ISD::FMINIMUMNUM, MVT::v16f16, Custom);
24122418
setOperationAction(ISD::FMAXIMUMNUM, MVT::v16f16, Custom);
2419+
setOperationAction(ISD::LRINT, MVT::v8f16, Legal);
2420+
setOperationAction(ISD::LRINT, MVT::v16f16, Legal);
24132421
}
24142422
}
24152423

@@ -34080,8 +34088,15 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
3408034088
case ISD::LRINT:
3408134089
if (N->getValueType(0) == MVT::v2i32) {
3408234090
SDValue Src = N->getOperand(0);
34083-
if (Src.getValueType() == MVT::v2f64)
34084-
Results.push_back(DAG.getNode(X86ISD::CVTP2SI, dl, MVT::v4i32, Src));
34091+
if (Subtarget.hasFP16() && Src.getValueType() == MVT::v2f16) {
34092+
Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f16, Src,
34093+
DAG.getUNDEF(MVT::v2f16));
34094+
Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8f16, Src,
34095+
DAG.getUNDEF(MVT::v4f16));
34096+
} else if (Src.getValueType() != MVT::v2f64) {
34097+
return;
34098+
}
34099+
Results.push_back(DAG.getNode(X86ISD::CVTP2SI, dl, MVT::v4i32, Src));
3408534100
return;
3408634101
}
3408734102
[[fallthrough]];
@@ -53687,13 +53702,35 @@ static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
5368753702
EVT SrcVT = Src.getValueType();
5368853703
SDLoc DL(N);
5368953704

53690-
if (!Subtarget.hasDQI() || !Subtarget.hasVLX() || VT != MVT::v2i64 ||
53691-
SrcVT != MVT::v2f32)
53705+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
53706+
53707+
// Let legalize expand this if it isn't a legal type yet.
53708+
if (!TLI.isTypeLegal(VT))
53709+
return SDValue();
53710+
53711+
if ((SrcVT.getScalarType() == MVT::f16 && !Subtarget.hasFP16()) ||
53712+
(SrcVT.getScalarType() == MVT::f32 && !Subtarget.hasDQI()))
5369253713
return SDValue();
5369353714

53694-
return DAG.getNode(X86ISD::CVTP2SI, DL, VT,
53695-
DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f32, Src,
53696-
DAG.getUNDEF(SrcVT)));
53715+
if (SrcVT == MVT::v2f16) {
53716+
SrcVT = MVT::v4f16;
53717+
Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53718+
DAG.getUNDEF(MVT::v2f16));
53719+
}
53720+
53721+
if (SrcVT == MVT::v4f16) {
53722+
SrcVT = MVT::v8f16;
53723+
Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53724+
DAG.getUNDEF(MVT::v4f16));
53725+
} else if (SrcVT == MVT::v2f32) {
53726+
SrcVT = MVT::v4f32;
53727+
Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53728+
DAG.getUNDEF(MVT::v2f32));
53729+
} else {
53730+
return SDValue();
53731+
}
53732+
53733+
return DAG.getNode(X86ISD::CVTP2SI, DL, VT, Src);
5369753734
}
5369853735

5369953736
// Attempt to fold some (truncate (srl (add X, C1), C2)) patterns to

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13143,6 +13143,26 @@ defm VCVTTPH2UQQ : avx512_cvttph2qq<0x78, "vcvttph2uqq", X86any_cvttp2ui,
1314313143
SchedWriteCvtPS2DQ>, T_MAP5, PD,
1314413144
EVEX_CD8<16, CD8VQ>;
1314513145

13146+
// Selection patterns for packed-half lrint, enabled only when both HasFP16 and
// HasVLX hold. Each source/result type pair gets two patterns: one taking the
// source in a register (rr) and one taking it from memory via loadv*f16 (rm).
//   v8f16  -> v8i16  : VCVTPH2WZ128
//   v16f16 -> v16i16 : VCVTPH2WZ256
//   v8f16  -> v8i32  : VCVTPH2DQZ256 (source stays in VR128X)
let Predicates = [HasFP16, HasVLX] in {
13147+
def : Pat<(v8i16 (lrint (v8f16 VR128X:$src))), (VCVTPH2WZ128rr VR128X:$src)>;
13148+
def : Pat<(v8i16 (lrint (loadv8f16 addr:$src))), (VCVTPH2WZ128rm addr:$src)>;
13149+
def : Pat<(v16i16 (lrint (v16f16 VR256X:$src))), (VCVTPH2WZ256rr VR256X:$src)>;
13150+
def : Pat<(v16i16 (lrint (loadv16f16 addr:$src))), (VCVTPH2WZ256rm addr:$src)>;
13151+
def : Pat<(v8i32 (lrint (v8f16 VR128X:$src))), (VCVTPH2DQZ256rr VR128X:$src)>;
13152+
def : Pat<(v8i32 (lrint (loadv8f16 addr:$src))), (VCVTPH2DQZ256rm addr:$src)>;
13153+
}
13154+
13155+
// Selection patterns for packed-half lrint/llrint that require only HasFP16.
// Each source/result type pair gets a register-source (rr) and a memory-source
// (rm) pattern.
//   v32f16 -> v32i16 : VCVTPH2WZ  (VR512 source)
//   v16f16 -> v16i32 : VCVTPH2DQZ (VR256X source)
//   v8f16  -> v8i64  : VCVTPH2QQZ (VR128X source)
// For the v8i64 result, lrint and llrint both select the same VCVTPH2QQZ
// instruction.
let Predicates = [HasFP16] in {
13156+
def : Pat<(v32i16 (lrint (v32f16 VR512:$src))), (VCVTPH2WZrr VR512:$src)>;
13157+
def : Pat<(v32i16 (lrint (loadv32f16 addr:$src))), (VCVTPH2WZrm addr:$src)>;
13158+
def : Pat<(v16i32 (lrint (v16f16 VR256X:$src))), (VCVTPH2DQZrr VR256X:$src)>;
13159+
def : Pat<(v16i32 (lrint (loadv16f16 addr:$src))), (VCVTPH2DQZrm addr:$src)>;
13160+
def : Pat<(v8i64 (lrint (v8f16 VR128X:$src))), (VCVTPH2QQZrr VR128X:$src)>;
13161+
def : Pat<(v8i64 (lrint (loadv8f16 addr:$src))), (VCVTPH2QQZrm addr:$src)>;
13162+
def : Pat<(v8i64 (llrint (v8f16 VR128X:$src))), (VCVTPH2QQZrr VR128X:$src)>;
13163+
def : Pat<(v8i64 (llrint (loadv8f16 addr:$src))), (VCVTPH2QQZrm addr:$src)>;
13164+
}
13165+
1314613166
// Convert Signed/Unsigned Quadword to Half
1314713167
multiclass avx512_cvtqq2ph<bits<8> opc, string OpcodeStr, SDPatternOperator OpNode,
1314813168
SDPatternOperator MaskOpNode, SDNode OpNodeRnd,
@@ -13269,6 +13289,19 @@ defm VCVTTSH2USI64Z: avx512_cvt_s_all<0x78, "vcvttsh2usi", f16x_info, i64x_info,
1326913289
any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSS2I,
1327013290
"{q}", HasFP16>, T_MAP5, XS, REX_W, EVEX_CD8<16, CD8VT1>;
1327113291

13292+
// Selection patterns for scalar half lrint producing i16 or i32, enabled under
// HasFP16. The i16 result is obtained by converting with VCVTTSH2SIZrr and then
// taking the sub_16bit subregister of that result. Only the i32 result has a
// memory-source (loadf16) pattern; there is no load form for the i16 result.
// NOTE(review): these patterns select VCVTTSH2SI*, whose "VCVTT" name suggests a
// truncating convert, while the packed lrint patterns in this file select the
// non-"T" VCVTPH2* forms. lrint is normally expected to round according to the
// current rounding mode -- confirm that the truncating form is intended here.
let Predicates = [HasFP16] in {
13293+
def : Pat<(i16 (lrint FR16:$src)), (EXTRACT_SUBREG (VCVTTSH2SIZrr FR16:$src), sub_16bit)>;
13294+
def : Pat<(i32 (lrint FR16:$src)), (VCVTTSH2SIZrr FR16:$src)>;
13295+
def : Pat<(i32 (lrint (loadf16 addr:$src))), (VCVTTSH2SIZrm addr:$src)>;
13296+
}
13297+
13298+
// Selection patterns for scalar half lrint/llrint producing i64, enabled only
// when both HasFP16 and In64BitMode hold. lrint and llrint select the same
// VCVTTSH2SI64Z instruction, each with a register-source (rr) and a
// memory-source (rm, via loadf16) pattern.
// NOTE(review): VCVTTSH2SI64Z is a "VCVTT" (apparently truncating) convert, while
// the packed lrint/llrint patterns in this file select the non-"T" VCVTPH2QQZ.
// lrint/llrint are normally expected to honor the current rounding mode --
// confirm that the truncating form is intended here.
let Predicates = [HasFP16, In64BitMode] in {
13299+
def : Pat<(i64 (lrint FR16:$src)), (VCVTTSH2SI64Zrr FR16:$src)>;
13300+
def : Pat<(i64 (lrint (loadf16 addr:$src))), (VCVTTSH2SI64Zrm addr:$src)>;
13301+
def : Pat<(i64 (llrint FR16:$src)), (VCVTTSH2SI64Zrr FR16:$src)>;
13302+
def : Pat<(i64 (llrint (loadf16 addr:$src))), (VCVTTSH2SI64Zrm addr:$src)>;
13303+
}
13304+
1327213305
let Predicates = [HasFP16] in {
1327313306
defm VCVTSI2SHZ : avx512_vcvtsi_common<0x2A, X86SintToFp, X86SintToFpRnd, WriteCvtI2SS, GR32,
1327413307
v8f16x_info, i32mem, loadi32, "cvtsi2sh", "l">,

0 commit comments

Comments
 (0)