Skip to content

Commit 9987239

Browse files
committed
[ARM] Add instruction selection for strict FP
This consists of marking the various strict opcodes as legal, and adjusting instruction selection patterns so that 'op' is 'any_op'. The changes are similar to those in D114946 for AArch64. Custom lowering and promotion are set for some FP16 strict ops to work correctly.
1 parent dda95d9 commit 9987239

File tree

5 files changed

+1226
-114
lines changed

5 files changed

+1226
-114
lines changed

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -601,10 +601,20 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM_,
601601
setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i64, Custom);
602602
setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i64, Custom);
603603

604-
if (!Subtarget->hasVFP2Base())
604+
if (!Subtarget->hasVFP2Base()) {
605605
setAllExpand(MVT::f32);
606-
if (!Subtarget->hasFP64())
606+
} else {
607+
for (auto Op : {ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL,
608+
ISD::STRICT_FDIV, ISD::STRICT_FMA, ISD::STRICT_FSQRT})
609+
setOperationAction(Op, MVT::f32, Legal);
610+
}
611+
if (!Subtarget->hasFP64()) {
607612
setAllExpand(MVT::f64);
613+
} else {
614+
for (auto Op : {ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL,
615+
ISD::STRICT_FDIV, ISD::STRICT_FMA, ISD::STRICT_FSQRT})
616+
setOperationAction(Op, MVT::f64, Legal);
617+
}
608618
}
609619

610620
if (Subtarget->hasFullFP16()) {
@@ -1333,31 +1343,44 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM_,
13331343
}
13341344

13351345
// FP16 often need to be promoted to call lib functions
1346+
// clang-format off
13361347
if (Subtarget->hasFullFP16()) {
1337-
setOperationAction(ISD::FREM, MVT::f16, Promote);
1338-
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
1339-
setOperationAction(ISD::FSIN, MVT::f16, Promote);
1340-
setOperationAction(ISD::FCOS, MVT::f16, Promote);
1341-
setOperationAction(ISD::FTAN, MVT::f16, Promote);
1342-
setOperationAction(ISD::FSINCOS, MVT::f16, Promote);
1343-
setOperationAction(ISD::FPOWI, MVT::f16, Promote);
1344-
setOperationAction(ISD::FPOW, MVT::f16, Promote);
1345-
setOperationAction(ISD::FEXP, MVT::f16, Promote);
1346-
setOperationAction(ISD::FEXP2, MVT::f16, Promote);
1347-
setOperationAction(ISD::FEXP10, MVT::f16, Promote);
1348-
setOperationAction(ISD::FLOG, MVT::f16, Promote);
1349-
setOperationAction(ISD::FLOG10, MVT::f16, Promote);
1350-
setOperationAction(ISD::FLOG2, MVT::f16, Promote);
13511348
setOperationAction(ISD::LRINT, MVT::f16, Expand);
13521349
setOperationAction(ISD::LROUND, MVT::f16, Expand);
1353-
1354-
setOperationAction(ISD::FROUND, MVT::f16, Legal);
1355-
setOperationAction(ISD::FROUNDEVEN, MVT::f16, Legal);
1356-
setOperationAction(ISD::FTRUNC, MVT::f16, Legal);
1357-
setOperationAction(ISD::FNEARBYINT, MVT::f16, Legal);
1358-
setOperationAction(ISD::FRINT, MVT::f16, Legal);
1359-
setOperationAction(ISD::FFLOOR, MVT::f16, Legal);
1360-
setOperationAction(ISD::FCEIL, MVT::f16, Legal);
1350+
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
1351+
1352+
for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
1353+
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
1354+
ISD::FSINCOSPI, ISD::FMODF, ISD::FACOS,
1355+
ISD::FASIN, ISD::FATAN, ISD::FATAN2,
1356+
ISD::FCOSH, ISD::FSINH, ISD::FTANH,
1357+
ISD::FTAN, ISD::FEXP, ISD::FEXP2,
1358+
ISD::FEXP10, ISD::FLOG, ISD::FLOG2,
1359+
ISD::FLOG10, ISD::STRICT_FREM, ISD::STRICT_FPOW,
1360+
ISD::STRICT_FPOWI, ISD::STRICT_FCOS, ISD::STRICT_FSIN,
1361+
ISD::STRICT_FACOS, ISD::STRICT_FASIN, ISD::STRICT_FATAN,
1362+
ISD::STRICT_FATAN2, ISD::STRICT_FCOSH, ISD::STRICT_FSINH,
1363+
ISD::STRICT_FTANH, ISD::STRICT_FEXP, ISD::STRICT_FEXP2,
1364+
ISD::STRICT_FLOG, ISD::STRICT_FLOG2, ISD::STRICT_FLOG10,
1365+
ISD::STRICT_FTAN}) {
1366+
setOperationAction(Op, MVT::f16, Promote);
1367+
}
1368+
1369+
// Round-to-integer need custom lowering for fp16, as Promote doesn't work
1370+
// because the result type is integer.
1371+
for (auto Op : {ISD::LROUND, ISD::LLROUND, ISD::LRINT, ISD::LLRINT,
1372+
ISD::STRICT_LROUND, ISD::STRICT_LLROUND, ISD::STRICT_LRINT,
1373+
ISD::STRICT_LLRINT})
1374+
setOperationAction(Op, MVT::f16, Custom);
1375+
1376+
for (auto Op : {ISD::FROUND, ISD::FROUNDEVEN, ISD::FTRUNC,
1377+
ISD::FNEARBYINT, ISD::FRINT, ISD::FFLOOR,
1378+
ISD::FCEIL, ISD::STRICT_FROUND, ISD::STRICT_FROUNDEVEN,
1379+
ISD::STRICT_FTRUNC, ISD::STRICT_FNEARBYINT, ISD::STRICT_FRINT,
1380+
ISD::STRICT_FFLOOR, ISD::STRICT_FCEIL}) {
1381+
setOperationAction(Op, MVT::f16, Legal);
1382+
}
1383+
// clang-format on
13611384
}
13621385

13631386
if (Subtarget->hasNEON()) {
@@ -10725,6 +10748,30 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1072510748
return LowerCMP(Op, DAG);
1072610749
case ISD::ABS:
1072710750
return LowerABS(Op, DAG);
10751+
case ISD::LRINT:
10752+
case ISD::LLRINT:
10753+
case ISD::LROUND:
10754+
case ISD::LLROUND: {
10755+
assert((Op.getOperand(0).getValueType() == MVT::f16 ||
10756+
Op.getOperand(1).getValueType() == MVT::bf16) &&
10757+
"Expected custom lowering of rounding operations only for f16");
10758+
SDLoc DL(Op);
10759+
SDValue Ext = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op.getOperand(0));
10760+
return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(), Ext);
10761+
}
10762+
case ISD::STRICT_LROUND:
10763+
case ISD::STRICT_LLROUND:
10764+
case ISD::STRICT_LRINT:
10765+
case ISD::STRICT_LLRINT: {
10766+
assert((Op.getOperand(1).getValueType() == MVT::f16 ||
10767+
Op.getOperand(1).getValueType() == MVT::bf16) &&
10768+
"Expected custom lowering of rounding operations only for f16");
10769+
SDLoc DL(Op);
10770+
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, DL, {MVT::f32, MVT::Other},
10771+
{Op.getOperand(0), Op.getOperand(1)});
10772+
return DAG.getNode(Op.getOpcode(), DL, {Op.getValueType(), MVT::Other},
10773+
{Ext.getValue(1), Ext.getValue(0)});
10774+
}
1072810775
}
1072910776
}
1073010777

llvm/lib/Target/ARM/ARMInstrInfo.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,15 +473,15 @@ def xor_su : PatFrag<(ops node:$lhs, node:$rhs), (xor node:$lhs, node:$rhs)>;
473473

474474
// An 'fmul' node with a single use.
475475
let HasOneUse = 1 in
476-
def fmul_su : PatFrag<(ops node:$lhs, node:$rhs), (fmul node:$lhs, node:$rhs)>;
476+
def fmul_su : PatFrag<(ops node:$lhs, node:$rhs), (any_fmul node:$lhs, node:$rhs)>;
477477

478478
// An 'fadd' node which checks for single non-hazardous use.
479-
def fadd_mlx : PatFrag<(ops node:$lhs, node:$rhs),(fadd node:$lhs, node:$rhs),[{
479+
def fadd_mlx : PatFrag<(ops node:$lhs, node:$rhs),(any_fadd node:$lhs, node:$rhs),[{
480480
return hasNoVMLxHazardUse(N);
481481
}]>;
482482

483483
// An 'fsub' node which checks for single non-hazardous use.
484-
def fsub_mlx : PatFrag<(ops node:$lhs, node:$rhs),(fsub node:$lhs, node:$rhs),[{
484+
def fsub_mlx : PatFrag<(ops node:$lhs, node:$rhs),(any_fsub node:$lhs, node:$rhs),[{
485485
return hasNoVMLxHazardUse(N);
486486
}]>;
487487

0 commit comments

Comments
 (0)