@@ -2481,13 +2481,12 @@ SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
2481
2481
return Subvectors[0];
2482
2482
}
2483
2483
2484
- bool SelectionDAG::expandFSINCOS(SDNode *Node,
2485
- SmallVectorImpl<SDValue> &Results) {
2484
+ bool SelectionDAG::expandMultipleResultFPLibCall(
2485
+ RTLIB::Libcall LC, SDNode *Node, SmallVectorImpl<SDValue> &Results,
2486
+ std::optional<unsigned> CallRetResNo) {
2487
+ LLVMContext &Ctx = *getContext();
2486
2488
EVT VT = Node->getValueType(0);
2487
- LLVMContext *Ctx = getContext();
2488
- Type *Ty = VT.getTypeForEVT(*Ctx);
2489
- RTLIB::Libcall LC =
2490
- RTLIB::getFSINCOS(VT.isVector() ? VT.getVectorElementType() : VT);
2489
+ unsigned NumResults = Node->getNumValues();
2491
2490
2492
2491
const char *LCName = TLI->getLibcallName(LC);
2493
2492
if (!LC || !LCName)
@@ -2503,78 +2502,116 @@ bool SelectionDAG::expandFSINCOS(SDNode *Node,
2503
2502
return nullptr;
2504
2503
};
2505
2504
2505
+ // For vector types, we must find a vector mapping for the libcall.
2506
2506
VecDesc const *VD = nullptr;
2507
2507
if (VT.isVector() && !(VD = getVecDesc()))
2508
2508
return false;
2509
2509
2510
2510
// Find users of the node that store the results (and share input chains). The
2511
2511
// destination pointers can be used instead of creating stack allocations.
2512
2512
SDValue StoresInChain{};
2513
- std::array <StoreSDNode *, 2> ResultStores = {nullptr} ;
2513
+ SmallVector <StoreSDNode *, 2> ResultStores(NumResults) ;
2514
2514
for (SDNode *User : Node->uses()) {
2515
2515
if (!ISD::isNormalStore(User))
2516
2516
continue;
2517
2517
auto *ST = cast<StoreSDNode>(User);
2518
- if (!ST->isSimple() || ST->getAddressSpace() != 0 ||
2519
- ST->getAlign() < getDataLayout().getABITypeAlign(Ty->getScalarType()) ||
2518
+ SDValue StoreValue = ST->getValue();
2519
+ unsigned ResNo = StoreValue.getResNo();
2520
+ Type *StoreType = StoreValue.getValueType().getTypeForEVT(Ctx);
2521
+ if (CallRetResNo == ResNo || !ST->isSimple() ||
2522
+ ST->getAddressSpace() != 0 ||
2523
+ ST->getAlign() <
2524
+ getDataLayout().getABITypeAlign(StoreType->getScalarType()) ||
2520
2525
(StoresInChain && ST->getChain() != StoresInChain) ||
2521
2526
Node->isPredecessorOf(ST->getChain().getNode()))
2522
2527
continue;
2523
- ResultStores[ST->getValue().getResNo() ] = ST;
2528
+ ResultStores[ResNo ] = ST;
2524
2529
StoresInChain = ST->getChain();
2525
2530
}
2526
2531
2527
2532
TargetLowering::ArgListTy Args;
2528
- TargetLowering::ArgListEntry Entry{};
2533
+ auto AddArgListEntry = [&](SDValue Node, Type *Ty) {
2534
+ TargetLowering::ArgListEntry Entry{};
2535
+ Entry.Ty = Ty;
2536
+ Entry.Node = Node;
2537
+ Args.push_back(Entry);
2538
+ };
2529
2539
2530
- // Pass the argument.
2531
- Entry.Node = Node->getOperand(0);
2532
- Entry.Ty = Ty;
2533
- Args.push_back(Entry);
2540
+ // Pass the arguments.
2541
+ for (const SDValue &Op : Node->op_values()) {
2542
+ EVT ArgVT = Op.getValueType();
2543
+ Type *ArgTy = ArgVT.getTypeForEVT(Ctx);
2544
+ AddArgListEntry(Op, ArgTy);
2545
+ }
2534
2546
2535
- // Pass the output pointers for sin and cos.
2536
- SmallVector<SDValue, 2> ResultPtrs{};
2537
- for (StoreSDNode *ST : ResultStores) {
2538
- SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(VT);
2539
- Entry.Node = ResultPtr;
2540
- Entry.Ty = PointerType::getUnqual(Ty->getContext());
2541
- Args.push_back(Entry);
2542
- ResultPtrs.push_back(ResultPtr);
2547
+ // Pass the output pointers.
2548
+ SmallVector<SDValue, 2> ResultPtrs(NumResults);
2549
+ Type *PointerTy = PointerType::getUnqual(Ctx);
2550
+ for (auto [ResNo, ST] : llvm::enumerate(ResultStores)) {
2551
+ if (ResNo == CallRetResNo)
2552
+ continue;
2553
+ EVT ResVT = Node->getValueType(ResNo);
2554
+ SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(ResVT);
2555
+ ResultPtrs[ResNo] = ResultPtr;
2556
+ AddArgListEntry(ResultPtr, PointerTy);
2543
2557
}
2544
2558
2545
2559
SDLoc DL(Node);
2546
2560
2561
+ // Pass the vector mask (if required).
2547
2562
if (VD && VD->isMasked()) {
2548
- EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), *Ctx, VT);
2549
- Entry.Node = getBoolConstant(true, DL, MaskVT, VT);
2550
- Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
2551
- Args.push_back(Entry);
2563
+ EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), Ctx, VT);
2564
+ SDValue Mask = getBoolConstant(true, DL, MaskVT, VT);
2565
+ AddArgListEntry(Mask, MaskVT.getTypeForEVT(Ctx));
2552
2566
}
2553
2567
2568
+ Type *RetType = CallRetResNo.has_value()
2569
+ ? Node->getValueType(*CallRetResNo).getTypeForEVT(Ctx)
2570
+ : Type::getVoidTy(Ctx);
2554
2571
SDValue InChain = StoresInChain ? StoresInChain : getEntryNode();
2555
2572
SDValue Callee = getExternalSymbol(VD ? VD->getVectorFnName().data() : LCName,
2556
2573
TLI->getPointerTy(getDataLayout()));
2557
2574
TargetLowering::CallLoweringInfo CLI(*this);
2558
2575
CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
2559
- TLI->getLibcallCallingConv(LC), Type::getVoidTy(*Ctx), Callee,
2560
- std::move(Args));
2576
+ TLI->getLibcallCallingConv(LC), RetType, Callee, std::move(Args));
2561
2577
2562
- auto [Call, OutChain ] = TLI->LowerCallTo(CLI);
2578
+ auto [Call, CallChain ] = TLI->LowerCallTo(CLI);
2563
2579
2564
2580
for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
2581
+ if (ResNo == CallRetResNo) {
2582
+ Results.push_back(Call);
2583
+ continue;
2584
+ }
2565
2585
MachinePointerInfo PtrInfo;
2566
2586
if (StoreSDNode *ST = ResultStores[ResNo]) {
2567
2587
// Replace store with the library call.
2568
- ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain );
2588
+ ReplaceAllUsesOfValueWith(SDValue(ST, 0), CallChain );
2569
2589
PtrInfo = ST->getPointerInfo();
2570
2590
} else {
2571
2591
PtrInfo = MachinePointerInfo::getFixedStack(
2572
2592
getMachineFunction(), cast<FrameIndexSDNode>(ResultPtr)->getIndex());
2573
2593
}
2574
- SDValue LoadResult = getLoad(VT, DL, OutChain, ResultPtr, PtrInfo);
2594
+ SDValue LoadResult =
2595
+ getLoad(Node->getValueType(ResNo), DL, CallChain, ResultPtr, PtrInfo);
2575
2596
Results.push_back(LoadResult);
2576
2597
}
2577
2598
2599
+ if (CallRetResNo && !Node->hasAnyUseOfValue(*CallRetResNo)) {
2600
+ // FIXME: Find a way to avoid updating the root. This is needed for x86,
2601
+ // which uses a floating-point stack. Suppose (for example) the node to be
2602
+ // expanded has two results: one floating-point, which is returned by the
2603
+ // call, and one integer, which is returned via an output pointer. If only
2604
+ // the integer result is used, the `CopyFromReg` for the FP result may be
2605
+ // optimized out. This prevents an FP stack pop from being emitted for it.
2606
+ // Setting the root like this ensures there will be a use of the
2607
+ // `CopyFromReg` chain, and ensures the FP pop will be emitted.
2608
+ SDValue NewRoot =
2609
+ getNode(ISD::TokenFactor, DL, MVT::Other, getRoot(), CallChain);
2610
+ setRoot(NewRoot);
2611
+ // Ensure the new root is reachable from the results.
2612
+ Results[0] = getMergeValues({Results[0], NewRoot}, DL);
2613
+ }
2614
+
2578
2615
return true;
2579
2616
}
2580
2617
0 commit comments