Skip to content

Commit ea6b8fa

Browse files
authored
[SDAG] Merge multiple-result libcall expansion into DAG.expandMultipleResultFPLibCall() (#114792)
This merges the logic for expanding both FFREXP and FSINCOS into one method `DAG.expandMultipleResultFPLibCall()`. This reduces duplication and also allows FFREXP to benefit from the stack slot elimination implemented for FSINCOS. This method will also be used in the future to implement more multiple-result intrinsics (such as modf and sincospi).
1 parent 40556d0 commit ea6b8fa

File tree

7 files changed

+196
-250
lines changed

7 files changed

+196
-250
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "llvm/IR/ConstantRange.h"
3535
#include "llvm/IR/DebugLoc.h"
3636
#include "llvm/IR/Metadata.h"
37+
#include "llvm/IR/RuntimeLibcalls.h"
3738
#include "llvm/Support/Allocator.h"
3839
#include "llvm/Support/ArrayRecycler.h"
3940
#include "llvm/Support/CodeGen.h"
@@ -1595,8 +1596,14 @@ class SelectionDAG {
15951596
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
15961597
SDValue Op2);
15971598

1598-
/// Expand the specified \c ISD::FSINCOS node as the Legalize pass would.
1599-
bool expandFSINCOS(SDNode *Node, SmallVectorImpl<SDValue> &Results);
1599+
/// Expands a node with multiple results to an FP or vector libcall. The
1600+
/// libcall is expected to take all the operands of the \p Node followed by
1601+
/// output pointers for each of the results. \p CallRetResNo can be optionally
1602+
/// set to indicate that one of the results comes from the libcall's return
1603+
/// value.
1604+
bool expandMultipleResultFPLibCall(RTLIB::Libcall LC, SDNode *Node,
1605+
SmallVectorImpl<SDValue> &Results,
1606+
std::optional<unsigned> CallRetResNo = {});
16001607

16011608
/// Expand the specified \c ISD::VAARG node as the Legalize pass would.
16021609
SDValue expandVAArg(SDNode *Node);

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 11 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ class SelectionDAGLegalize {
132132
TargetLowering::ArgListTy &&Args, bool isSigned);
133133
std::pair<SDValue, SDValue> ExpandLibCall(RTLIB::Libcall LC, SDNode *Node, bool isSigned);
134134

135-
void ExpandFrexpLibCall(SDNode *Node, SmallVectorImpl<SDValue> &Results);
136135
void ExpandFPLibCall(SDNode *Node, RTLIB::Libcall LC,
137136
SmallVectorImpl<SDValue> &Results);
138137
void ExpandFPLibCall(SDNode *Node, RTLIB::Libcall Call_F32,
@@ -2144,47 +2143,6 @@ std::pair<SDValue, SDValue> SelectionDAGLegalize::ExpandLibCall(RTLIB::Libcall L
21442143
return ExpandLibCall(LC, Node, std::move(Args), isSigned);
21452144
}
21462145

2147-
void SelectionDAGLegalize::ExpandFrexpLibCall(
2148-
SDNode *Node, SmallVectorImpl<SDValue> &Results) {
2149-
SDLoc dl(Node);
2150-
EVT VT = Node->getValueType(0);
2151-
EVT ExpVT = Node->getValueType(1);
2152-
2153-
SDValue FPOp = Node->getOperand(0);
2154-
2155-
EVT ArgVT = FPOp.getValueType();
2156-
Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext());
2157-
2158-
TargetLowering::ArgListEntry FPArgEntry;
2159-
FPArgEntry.Node = FPOp;
2160-
FPArgEntry.Ty = ArgTy;
2161-
2162-
SDValue StackSlot = DAG.CreateStackTemporary(ExpVT);
2163-
TargetLowering::ArgListEntry PtrArgEntry;
2164-
PtrArgEntry.Node = StackSlot;
2165-
PtrArgEntry.Ty = PointerType::get(*DAG.getContext(),
2166-
DAG.getDataLayout().getAllocaAddrSpace());
2167-
2168-
TargetLowering::ArgListTy Args = {FPArgEntry, PtrArgEntry};
2169-
2170-
RTLIB::Libcall LC = RTLIB::getFREXP(VT);
2171-
auto [Call, Chain] = ExpandLibCall(LC, Node, std::move(Args), false);
2172-
2173-
// FIXME: Get type of int for libcall declaration and cast
2174-
2175-
int FrameIdx = cast<FrameIndexSDNode>(StackSlot)->getIndex();
2176-
auto PtrInfo =
2177-
MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FrameIdx);
2178-
2179-
SDValue LoadExp = DAG.getLoad(ExpVT, dl, Chain, StackSlot, PtrInfo);
2180-
SDValue OutputChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
2181-
LoadExp.getValue(1), DAG.getRoot());
2182-
DAG.setRoot(OutputChain);
2183-
2184-
Results.push_back(Call);
2185-
Results.push_back(LoadExp);
2186-
}
2187-
21882146
void SelectionDAGLegalize::ExpandFPLibCall(SDNode* Node,
21892147
RTLIB::Libcall LC,
21902148
SmallVectorImpl<SDValue> &Results) {
@@ -4562,10 +4520,13 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
45624520
ExpandFPLibCall(Node, RTLIB::TANH_F32, RTLIB::TANH_F64, RTLIB::TANH_F80,
45634521
RTLIB::TANH_F128, RTLIB::TANH_PPCF128, Results);
45644522
break;
4565-
case ISD::FSINCOS:
4566-
// Expand into sincos libcall.
4567-
(void)DAG.expandFSINCOS(Node, Results);
4523+
case ISD::FSINCOS: {
4524+
RTLIB::Libcall LC = RTLIB::getFSINCOS(Node->getValueType(0));
4525+
bool Expanded = DAG.expandMultipleResultFPLibCall(LC, Node, Results);
4526+
if (!Expanded)
4527+
llvm_unreachable("Expected scalar FSINCOS to expand to libcall!");
45684528
break;
4529+
}
45694530
case ISD::FLOG:
45704531
case ISD::STRICT_FLOG:
45714532
ExpandFPLibCall(Node, RTLIB::LOG_F32, RTLIB::LOG_F64, RTLIB::LOG_F80,
@@ -4649,7 +4610,11 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
46494610
RTLIB::LDEXP_F128, RTLIB::LDEXP_PPCF128, Results);
46504611
break;
46514612
case ISD::FFREXP: {
4652-
ExpandFrexpLibCall(Node, Results);
4613+
RTLIB::Libcall LC = RTLIB::getFREXP(Node->getValueType(0));
4614+
bool Expanded = DAG.expandMultipleResultFPLibCall(LC, Node, Results,
4615+
/*CallRetResNo=*/0);
4616+
if (!Expanded)
4617+
llvm_unreachable("Expected scalar FFREXP to expand to libcall!");
46534618
break;
46544619
}
46554620
case ISD::FPOWI:

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,11 +1192,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
11921192
return;
11931193

11941194
break;
1195-
case ISD::FSINCOS:
1196-
if (DAG.expandFSINCOS(Node, Results))
1195+
case ISD::FSINCOS: {
1196+
RTLIB::Libcall LC =
1197+
RTLIB::getFSINCOS(Node->getValueType(0).getVectorElementType());
1198+
if (DAG.expandMultipleResultFPLibCall(LC, Node, Results))
11971199
return;
1198-
11991200
break;
1201+
}
12001202
case ISD::VECTOR_COMPRESS:
12011203
Results.push_back(TLI.expandVECTOR_COMPRESS(Node, DAG));
12021204
return;

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 69 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2481,13 +2481,12 @@ SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
24812481
return Subvectors[0];
24822482
}
24832483

2484-
bool SelectionDAG::expandFSINCOS(SDNode *Node,
2485-
SmallVectorImpl<SDValue> &Results) {
2484+
bool SelectionDAG::expandMultipleResultFPLibCall(
2485+
RTLIB::Libcall LC, SDNode *Node, SmallVectorImpl<SDValue> &Results,
2486+
std::optional<unsigned> CallRetResNo) {
2487+
LLVMContext &Ctx = *getContext();
24862488
EVT VT = Node->getValueType(0);
2487-
LLVMContext *Ctx = getContext();
2488-
Type *Ty = VT.getTypeForEVT(*Ctx);
2489-
RTLIB::Libcall LC =
2490-
RTLIB::getFSINCOS(VT.isVector() ? VT.getVectorElementType() : VT);
2489+
unsigned NumResults = Node->getNumValues();
24912490

24922491
const char *LCName = TLI->getLibcallName(LC);
24932492
if (!LC || !LCName)
@@ -2503,78 +2502,116 @@ bool SelectionDAG::expandFSINCOS(SDNode *Node,
25032502
return nullptr;
25042503
};
25052504

2505+
// For vector types, we must find a vector mapping for the libcall.
25062506
VecDesc const *VD = nullptr;
25072507
if (VT.isVector() && !(VD = getVecDesc()))
25082508
return false;
25092509

25102510
// Find users of the node that store the results (and share input chains). The
25112511
// destination pointers can be used instead of creating stack allocations.
25122512
SDValue StoresInChain{};
2513-
std::array<StoreSDNode *, 2> ResultStores = {nullptr};
2513+
SmallVector<StoreSDNode *, 2> ResultStores(NumResults);
25142514
for (SDNode *User : Node->uses()) {
25152515
if (!ISD::isNormalStore(User))
25162516
continue;
25172517
auto *ST = cast<StoreSDNode>(User);
2518-
if (!ST->isSimple() || ST->getAddressSpace() != 0 ||
2519-
ST->getAlign() < getDataLayout().getABITypeAlign(Ty->getScalarType()) ||
2518+
SDValue StoreValue = ST->getValue();
2519+
unsigned ResNo = StoreValue.getResNo();
2520+
Type *StoreType = StoreValue.getValueType().getTypeForEVT(Ctx);
2521+
if (CallRetResNo == ResNo || !ST->isSimple() ||
2522+
ST->getAddressSpace() != 0 ||
2523+
ST->getAlign() <
2524+
getDataLayout().getABITypeAlign(StoreType->getScalarType()) ||
25202525
(StoresInChain && ST->getChain() != StoresInChain) ||
25212526
Node->isPredecessorOf(ST->getChain().getNode()))
25222527
continue;
2523-
ResultStores[ST->getValue().getResNo()] = ST;
2528+
ResultStores[ResNo] = ST;
25242529
StoresInChain = ST->getChain();
25252530
}
25262531

25272532
TargetLowering::ArgListTy Args;
2528-
TargetLowering::ArgListEntry Entry{};
2533+
auto AddArgListEntry = [&](SDValue Node, Type *Ty) {
2534+
TargetLowering::ArgListEntry Entry{};
2535+
Entry.Ty = Ty;
2536+
Entry.Node = Node;
2537+
Args.push_back(Entry);
2538+
};
25292539

2530-
// Pass the argument.
2531-
Entry.Node = Node->getOperand(0);
2532-
Entry.Ty = Ty;
2533-
Args.push_back(Entry);
2540+
// Pass the arguments.
2541+
for (const SDValue &Op : Node->op_values()) {
2542+
EVT ArgVT = Op.getValueType();
2543+
Type *ArgTy = ArgVT.getTypeForEVT(Ctx);
2544+
AddArgListEntry(Op, ArgTy);
2545+
}
25342546

2535-
// Pass the output pointers for sin and cos.
2536-
SmallVector<SDValue, 2> ResultPtrs{};
2537-
for (StoreSDNode *ST : ResultStores) {
2538-
SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(VT);
2539-
Entry.Node = ResultPtr;
2540-
Entry.Ty = PointerType::getUnqual(Ty->getContext());
2541-
Args.push_back(Entry);
2542-
ResultPtrs.push_back(ResultPtr);
2547+
// Pass the output pointers.
2548+
SmallVector<SDValue, 2> ResultPtrs(NumResults);
2549+
Type *PointerTy = PointerType::getUnqual(Ctx);
2550+
for (auto [ResNo, ST] : llvm::enumerate(ResultStores)) {
2551+
if (ResNo == CallRetResNo)
2552+
continue;
2553+
EVT ResVT = Node->getValueType(ResNo);
2554+
SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(ResVT);
2555+
ResultPtrs[ResNo] = ResultPtr;
2556+
AddArgListEntry(ResultPtr, PointerTy);
25432557
}
25442558

25452559
SDLoc DL(Node);
25462560

2561+
// Pass the vector mask (if required).
25472562
if (VD && VD->isMasked()) {
2548-
EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), *Ctx, VT);
2549-
Entry.Node = getBoolConstant(true, DL, MaskVT, VT);
2550-
Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
2551-
Args.push_back(Entry);
2563+
EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), Ctx, VT);
2564+
SDValue Mask = getBoolConstant(true, DL, MaskVT, VT);
2565+
AddArgListEntry(Mask, MaskVT.getTypeForEVT(Ctx));
25522566
}
25532567

2568+
Type *RetType = CallRetResNo.has_value()
2569+
? Node->getValueType(*CallRetResNo).getTypeForEVT(Ctx)
2570+
: Type::getVoidTy(Ctx);
25542571
SDValue InChain = StoresInChain ? StoresInChain : getEntryNode();
25552572
SDValue Callee = getExternalSymbol(VD ? VD->getVectorFnName().data() : LCName,
25562573
TLI->getPointerTy(getDataLayout()));
25572574
TargetLowering::CallLoweringInfo CLI(*this);
25582575
CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
2559-
TLI->getLibcallCallingConv(LC), Type::getVoidTy(*Ctx), Callee,
2560-
std::move(Args));
2576+
TLI->getLibcallCallingConv(LC), RetType, Callee, std::move(Args));
25612577

2562-
auto [Call, OutChain] = TLI->LowerCallTo(CLI);
2578+
auto [Call, CallChain] = TLI->LowerCallTo(CLI);
25632579

25642580
for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
2581+
if (ResNo == CallRetResNo) {
2582+
Results.push_back(Call);
2583+
continue;
2584+
}
25652585
MachinePointerInfo PtrInfo;
25662586
if (StoreSDNode *ST = ResultStores[ResNo]) {
25672587
// Replace store with the library call.
2568-
ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain);
2588+
ReplaceAllUsesOfValueWith(SDValue(ST, 0), CallChain);
25692589
PtrInfo = ST->getPointerInfo();
25702590
} else {
25712591
PtrInfo = MachinePointerInfo::getFixedStack(
25722592
getMachineFunction(), cast<FrameIndexSDNode>(ResultPtr)->getIndex());
25732593
}
2574-
SDValue LoadResult = getLoad(VT, DL, OutChain, ResultPtr, PtrInfo);
2594+
SDValue LoadResult =
2595+
getLoad(Node->getValueType(ResNo), DL, CallChain, ResultPtr, PtrInfo);
25752596
Results.push_back(LoadResult);
25762597
}
25772598

2599+
if (CallRetResNo && !Node->hasAnyUseOfValue(*CallRetResNo)) {
2600+
// FIXME: Find a way to avoid updating the root. This is needed for x86,
2601+
// which uses a floating-point stack. Suppose (for example) the node to be
2602+
// expanded has two results: one floating-point, which is returned by the
2603+
// call, and one integer result, returned via an output pointer. If only the
2604+
// integer result is used then the `CopyFromReg` for the FP result may be
2605+
// optimized out. This prevents an FP stack pop from being emitted for it.
2606+
// Setting the root like this ensures there will be a use of the
2607+
// `CopyFromReg` chain, and ensures the FP pop will be emitted.
2608+
SDValue NewRoot =
2609+
getNode(ISD::TokenFactor, DL, MVT::Other, getRoot(), CallChain);
2610+
setRoot(NewRoot);
2611+
// Ensure the new root is reachable from the results.
2612+
Results[0] = getMergeValues({Results[0], NewRoot}, DL);
2613+
}
2614+
25782615
return true;
25792616
}
25802617

llvm/test/CodeGen/PowerPC/f128-arith.ll

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,45 +1365,33 @@ define dso_local fp128 @qpFREXP(ptr %a, ptr %b) {
13651365
; CHECK-LABEL: qpFREXP:
13661366
; CHECK: # %bb.0: # %entry
13671367
; CHECK-NEXT: mflr r0
1368-
; CHECK-NEXT: .cfi_def_cfa_offset 64
1368+
; CHECK-NEXT: stdu r1, -32(r1)
1369+
; CHECK-NEXT: std r0, 48(r1)
1370+
; CHECK-NEXT: .cfi_def_cfa_offset 32
13691371
; CHECK-NEXT: .cfi_offset lr, 16
1370-
; CHECK-NEXT: .cfi_offset r30, -16
1371-
; CHECK-NEXT: std r30, -16(r1) # 8-byte Folded Spill
1372-
; CHECK-NEXT: stdu r1, -64(r1)
1373-
; CHECK-NEXT: std r0, 80(r1)
1374-
; CHECK-NEXT: addi r5, r1, 44
1375-
; CHECK-NEXT: mr r30, r4
13761372
; CHECK-NEXT: lxv v2, 0(r3)
1373+
; CHECK-NEXT: mr r5, r4
13771374
; CHECK-NEXT: bl frexpf128
13781375
; CHECK-NEXT: nop
1379-
; CHECK-NEXT: lwz r3, 44(r1)
1380-
; CHECK-NEXT: stw r3, 0(r30)
1381-
; CHECK-NEXT: addi r1, r1, 64
1376+
; CHECK-NEXT: addi r1, r1, 32
13821377
; CHECK-NEXT: ld r0, 16(r1)
1383-
; CHECK-NEXT: ld r30, -16(r1) # 8-byte Folded Reload
13841378
; CHECK-NEXT: mtlr r0
13851379
; CHECK-NEXT: blr
13861380
;
13871381
; CHECK-P8-LABEL: qpFREXP:
13881382
; CHECK-P8: # %bb.0: # %entry
13891383
; CHECK-P8-NEXT: mflr r0
1390-
; CHECK-P8-NEXT: .cfi_def_cfa_offset 64
1384+
; CHECK-P8-NEXT: stdu r1, -32(r1)
1385+
; CHECK-P8-NEXT: std r0, 48(r1)
1386+
; CHECK-P8-NEXT: .cfi_def_cfa_offset 32
13911387
; CHECK-P8-NEXT: .cfi_offset lr, 16
1392-
; CHECK-P8-NEXT: .cfi_offset r30, -16
1393-
; CHECK-P8-NEXT: std r30, -16(r1) # 8-byte Folded Spill
1394-
; CHECK-P8-NEXT: stdu r1, -64(r1)
1395-
; CHECK-P8-NEXT: std r0, 80(r1)
1396-
; CHECK-P8-NEXT: addi r5, r1, 44
1397-
; CHECK-P8-NEXT: mr r30, r4
13981388
; CHECK-P8-NEXT: lxvd2x vs0, 0, r3
1389+
; CHECK-P8-NEXT: mr r5, r4
13991390
; CHECK-P8-NEXT: xxswapd v2, vs0
14001391
; CHECK-P8-NEXT: bl frexpf128
14011392
; CHECK-P8-NEXT: nop
1402-
; CHECK-P8-NEXT: lwz r3, 44(r1)
1403-
; CHECK-P8-NEXT: stw r3, 0(r30)
1404-
; CHECK-P8-NEXT: addi r1, r1, 64
1393+
; CHECK-P8-NEXT: addi r1, r1, 32
14051394
; CHECK-P8-NEXT: ld r0, 16(r1)
1406-
; CHECK-P8-NEXT: ld r30, -16(r1) # 8-byte Folded Reload
14071395
; CHECK-P8-NEXT: mtlr r0
14081396
; CHECK-P8-NEXT: blr
14091397
entry:

0 commit comments

Comments
 (0)