Skip to content

Commit 915aef3

Browse files
committed
[SDAG] Support expanding FSINCOS to vector library calls
This shares most of its code with the scalar sincos expansion. It allows expanding vector FSINCOS nodes to a library call from the specified `-vector-library`. The upside of this is it will mean the vectorizer only needs to handle the sincos intrinsic, which has no memory effects, and this can handle lowering the intrinsic to a call that takes output pointers.
1 parent 98c8d64 commit 915aef3

File tree

5 files changed

+216
-70
lines changed

5 files changed

+216
-70
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,6 +1599,9 @@ class SelectionDAG {
15991599
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
16001600
SDValue Op2);
16011601

1602+
/// Expand the specified \c ISD::FSINCOS node as the Legalize pass would.
1603+
bool expandFSINCOS(SDNode *Node, SmallVectorImpl<SDValue> &Results);
1604+
16021605
/// Expand the specified \c ISD::VAARG node as the Legalize pass would.
16031606
SDValue expandVAArg(SDNode *Node);
16041607

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 1 addition & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2348,75 +2348,6 @@ static bool useSinCos(SDNode *Node) {
23482348
return false;
23492349
}
23502350

2351-
/// Issue libcalls to sincos to compute sin / cos pairs.
2352-
void SelectionDAGLegalize::ExpandSinCosLibCall(
2353-
SDNode *Node, SmallVectorImpl<SDValue> &Results) {
2354-
EVT VT = Node->getValueType(0);
2355-
Type *Ty = VT.getTypeForEVT(*DAG.getContext());
2356-
RTLIB::Libcall LC = RTLIB::getFSINCOS(VT);
2357-
2358-
// Find users of the node that store the results (and share input chains). The
2359-
// destination pointers can be used instead of creating stack allocations.
2360-
SDValue StoresInChain{};
2361-
std::array<StoreSDNode *, 2> ResultStores = {nullptr};
2362-
for (SDNode *User : Node->uses()) {
2363-
if (!ISD::isNormalStore(User))
2364-
continue;
2365-
auto *ST = cast<StoreSDNode>(User);
2366-
if (!ST->isSimple() || ST->getAddressSpace() != 0 ||
2367-
ST->getAlign() < DAG.getDataLayout().getABITypeAlign(Ty) ||
2368-
(StoresInChain && ST->getChain() != StoresInChain) ||
2369-
Node->isPredecessorOf(ST->getChain().getNode()))
2370-
continue;
2371-
ResultStores[ST->getValue().getResNo()] = ST;
2372-
StoresInChain = ST->getChain();
2373-
}
2374-
2375-
TargetLowering::ArgListTy Args;
2376-
TargetLowering::ArgListEntry Entry{};
2377-
2378-
// Pass the argument.
2379-
Entry.Node = Node->getOperand(0);
2380-
Entry.Ty = Ty;
2381-
Args.push_back(Entry);
2382-
2383-
// Pass the output pointers for sin and cos.
2384-
SmallVector<SDValue, 2> ResultPtrs{};
2385-
for (StoreSDNode *ST : ResultStores) {
2386-
SDValue ResultPtr = ST ? ST->getBasePtr() : DAG.CreateStackTemporary(VT);
2387-
Entry.Node = ResultPtr;
2388-
Entry.Ty = PointerType::getUnqual(Ty->getContext());
2389-
Args.push_back(Entry);
2390-
ResultPtrs.push_back(ResultPtr);
2391-
}
2392-
2393-
SDLoc DL(Node);
2394-
SDValue InChain = StoresInChain ? StoresInChain : DAG.getEntryNode();
2395-
SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
2396-
TLI.getPointerTy(DAG.getDataLayout()));
2397-
TargetLowering::CallLoweringInfo CLI(DAG);
2398-
CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
2399-
TLI.getLibcallCallingConv(LC), Type::getVoidTy(*DAG.getContext()), Callee,
2400-
std::move(Args));
2401-
2402-
auto [Call, OutChain] = TLI.LowerCallTo(CLI);
2403-
2404-
for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
2405-
MachinePointerInfo PtrInfo;
2406-
if (StoreSDNode *ST = ResultStores[ResNo]) {
2407-
// Replace store with the library call.
2408-
DAG.ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain);
2409-
PtrInfo = ST->getPointerInfo();
2410-
} else {
2411-
PtrInfo = MachinePointerInfo::getFixedStack(
2412-
DAG.getMachineFunction(),
2413-
cast<FrameIndexSDNode>(ResultPtr)->getIndex());
2414-
}
2415-
SDValue LoadResult = DAG.getLoad(VT, DL, OutChain, ResultPtr, PtrInfo);
2416-
Results.push_back(LoadResult);
2417-
}
2418-
}
2419-
24202351
SDValue SelectionDAGLegalize::expandLdexp(SDNode *Node) const {
24212352
SDLoc dl(Node);
24222353
EVT VT = Node->getValueType(0);
@@ -4633,7 +4564,7 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
46334564
break;
46344565
case ISD::FSINCOS:
46354566
// Expand into sincos libcall.
4636-
ExpandSinCosLibCall(Node, Results);
4567+
(void)DAG.expandFSINCOS(Node, Results);
46374568
break;
46384569
case ISD::FLOG:
46394570
case ISD::STRICT_FLOG:

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,11 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
11911191
RTLIB::REM_PPCF128, Results))
11921192
return;
11931193

1194+
break;
1195+
case ISD::FSINCOS:
1196+
if (DAG.expandFSINCOS(Node, Results))
1197+
return;
1198+
11941199
break;
11951200
case ISD::VECTOR_COMPRESS:
11961201
Results.push_back(TLI.expandVECTOR_COMPRESS(Node, DAG));

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "llvm/ADT/Twine.h"
2626
#include "llvm/Analysis/AliasAnalysis.h"
2727
#include "llvm/Analysis/MemoryLocation.h"
28+
#include "llvm/Analysis/TargetLibraryInfo.h"
2829
#include "llvm/Analysis/ValueTracking.h"
2930
#include "llvm/Analysis/VectorUtils.h"
3031
#include "llvm/BinaryFormat/Dwarf.h"
@@ -2483,6 +2484,103 @@ SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
24832484
return Subvectors[0];
24842485
}
24852486

2487+
bool SelectionDAG::expandFSINCOS(SDNode *Node,
2488+
SmallVectorImpl<SDValue> &Results) {
2489+
EVT VT = Node->getValueType(0);
2490+
LLVMContext *Ctx = getContext();
2491+
Type *Ty = VT.getTypeForEVT(*Ctx);
2492+
RTLIB::Libcall LC =
2493+
RTLIB::getFSINCOS(VT.isVector() ? VT.getVectorElementType() : VT);
2494+
2495+
const char *LCName = TLI->getLibcallName(LC);
2496+
if (!LC || !LCName)
2497+
return false;
2498+
2499+
auto getVecDesc = [&]() -> VecDesc const * {
2500+
for (bool Masked : {false, true}) {
2501+
if (VecDesc const *VD = getLibInfo().getVectorMappingInfo(
2502+
LCName, VT.getVectorElementCount(), Masked)) {
2503+
return VD;
2504+
}
2505+
}
2506+
return nullptr;
2507+
};
2508+
2509+
VecDesc const *VD = nullptr;
2510+
if (VT.isVector() && !(VD = getVecDesc()))
2511+
return false;
2512+
2513+
// Find users of the node that store the results (and share input chains). The
2514+
// destination pointers can be used instead of creating stack allocations.
2515+
SDValue StoresInChain{};
2516+
std::array<StoreSDNode *, 2> ResultStores = {nullptr};
2517+
for (SDNode *User : Node->uses()) {
2518+
if (!ISD::isNormalStore(User))
2519+
continue;
2520+
auto *ST = cast<StoreSDNode>(User);
2521+
if (!ST->isSimple() || ST->getAddressSpace() != 0 ||
2522+
ST->getAlign() < getDataLayout().getABITypeAlign(Ty->getScalarType()) ||
2523+
(StoresInChain && ST->getChain() != StoresInChain) ||
2524+
Node->isPredecessorOf(ST->getChain().getNode()))
2525+
continue;
2526+
ResultStores[ST->getValue().getResNo()] = ST;
2527+
StoresInChain = ST->getChain();
2528+
}
2529+
2530+
TargetLowering::ArgListTy Args;
2531+
TargetLowering::ArgListEntry Entry{};
2532+
2533+
// Pass the argument.
2534+
Entry.Node = Node->getOperand(0);
2535+
Entry.Ty = Ty;
2536+
Args.push_back(Entry);
2537+
2538+
// Pass the output pointers for sin and cos.
2539+
SmallVector<SDValue, 2> ResultPtrs{};
2540+
for (StoreSDNode *ST : ResultStores) {
2541+
SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(VT);
2542+
Entry.Node = ResultPtr;
2543+
Entry.Ty = PointerType::getUnqual(Ty->getContext());
2544+
Args.push_back(Entry);
2545+
ResultPtrs.push_back(ResultPtr);
2546+
}
2547+
2548+
SDLoc DL(Node);
2549+
2550+
if (VD && VD->isMasked()) {
2551+
EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), *Ctx, VT);
2552+
Entry.Node = getBoolConstant(true, DL, MaskVT, VT);
2553+
Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
2554+
Args.push_back(Entry);
2555+
}
2556+
2557+
SDValue InChain = StoresInChain ? StoresInChain : getEntryNode();
2558+
SDValue Callee = getExternalSymbol(VD ? VD->getVectorFnName().data() : LCName,
2559+
TLI->getPointerTy(getDataLayout()));
2560+
TargetLowering::CallLoweringInfo CLI(*this);
2561+
CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
2562+
TLI->getLibcallCallingConv(LC), Type::getVoidTy(*Ctx), Callee,
2563+
std::move(Args));
2564+
2565+
auto [Call, OutChain] = TLI->LowerCallTo(CLI);
2566+
2567+
for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
2568+
MachinePointerInfo PtrInfo;
2569+
if (StoreSDNode *ST = ResultStores[ResNo]) {
2570+
// Replace store with the library call.
2571+
ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain);
2572+
PtrInfo = ST->getPointerInfo();
2573+
} else {
2574+
PtrInfo = MachinePointerInfo::getFixedStack(
2575+
getMachineFunction(), cast<FrameIndexSDNode>(ResultPtr)->getIndex());
2576+
}
2577+
SDValue LoadResult = getLoad(VT, DL, OutChain, ResultPtr, PtrInfo);
2578+
Results.push_back(LoadResult);
2579+
}
2580+
2581+
return true;
2582+
}
2583+
24862584
SDValue SelectionDAG::expandVAArg(SDNode *Node) {
24872585
SDLoc dl(Node);
24882586
const TargetLowering &TLI = getTargetLoweringInfo();
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mtriple=aarch64-gnu-linux -mattr=+neon,+sve -vector-library=sleefgnuabi < %s | FileCheck %s -check-prefix=SLEEF
3+
; RUN: llc -mtriple=aarch64-gnu-linux -mattr=+neon,+sve -vector-library=ArmPL < %s | FileCheck %s -check-prefix=ARMPL
4+
5+
define void @test_sincos_v4f32(<4 x float> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
6+
; SLEEF-LABEL: test_sincos_v4f32:
7+
; SLEEF: // %bb.0:
8+
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
9+
; SLEEF-NEXT: .cfi_def_cfa_offset 16
10+
; SLEEF-NEXT: .cfi_offset w30, -16
11+
; SLEEF-NEXT: bl _ZGVnN4vl4l4_sincosf
12+
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
13+
; SLEEF-NEXT: ret
14+
;
15+
; ARMPL-LABEL: test_sincos_v4f32:
16+
; ARMPL: // %bb.0:
17+
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
18+
; ARMPL-NEXT: .cfi_def_cfa_offset 16
19+
; ARMPL-NEXT: .cfi_offset w30, -16
20+
; ARMPL-NEXT: bl armpl_vsincosq_f32
21+
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
22+
; ARMPL-NEXT: ret
23+
%result = call { <4 x float>, <4 x float> } @llvm.sincos.v4f32(<4 x float> %x)
24+
%result.0 = extractvalue { <4 x float>, <4 x float> } %result, 0
25+
%result.1 = extractvalue { <4 x float>, <4 x float> } %result, 1
26+
store <4 x float> %result.0, ptr %out_sin, align 4
27+
store <4 x float> %result.1, ptr %out_cos, align 4
28+
ret void
29+
}
30+
31+
define void @test_sincos_v2f64(<2 x double> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
32+
; SLEEF-LABEL: test_sincos_v2f64:
33+
; SLEEF: // %bb.0:
34+
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
35+
; SLEEF-NEXT: .cfi_def_cfa_offset 16
36+
; SLEEF-NEXT: .cfi_offset w30, -16
37+
; SLEEF-NEXT: bl _ZGVnN2vl8l8_sincos
38+
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
39+
; SLEEF-NEXT: ret
40+
;
41+
; ARMPL-LABEL: test_sincos_v2f64:
42+
; ARMPL: // %bb.0:
43+
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
44+
; ARMPL-NEXT: .cfi_def_cfa_offset 16
45+
; ARMPL-NEXT: .cfi_offset w30, -16
46+
; ARMPL-NEXT: bl armpl_vsincosq_f64
47+
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
48+
; ARMPL-NEXT: ret
49+
%result = call { <2 x double>, <2 x double> } @llvm.sincos.v2f64(<2 x double> %x)
50+
%result.0 = extractvalue { <2 x double>, <2 x double> } %result, 0
51+
%result.1 = extractvalue { <2 x double>, <2 x double> } %result, 1
52+
store <2 x double> %result.0, ptr %out_sin, align 8
53+
store <2 x double> %result.1, ptr %out_cos, align 8
54+
ret void
55+
}
56+
57+
define void @test_sincos_nxv4f32(<vscale x 4 x float> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
58+
; SLEEF-LABEL: test_sincos_nxv4f32:
59+
; SLEEF: // %bb.0:
60+
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
61+
; SLEEF-NEXT: .cfi_def_cfa_offset 16
62+
; SLEEF-NEXT: .cfi_offset w30, -16
63+
; SLEEF-NEXT: bl _ZGVsNxvl4l4_sincosf
64+
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
65+
; SLEEF-NEXT: ret
66+
;
67+
; ARMPL-LABEL: test_sincos_nxv4f32:
68+
; ARMPL: // %bb.0:
69+
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
70+
; ARMPL-NEXT: .cfi_def_cfa_offset 16
71+
; ARMPL-NEXT: .cfi_offset w30, -16
72+
; ARMPL-NEXT: ptrue p0.s
73+
; ARMPL-NEXT: bl armpl_svsincos_f32_x
74+
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
75+
; ARMPL-NEXT: ret
76+
%result = call { <vscale x 4 x float>, <vscale x 4 x float> } @llvm.sincos.nxv4f32(<vscale x 4 x float> %x)
77+
%result.0 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %result, 0
78+
%result.1 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %result, 1
79+
store <vscale x 4 x float> %result.0, ptr %out_sin, align 4
80+
store <vscale x 4 x float> %result.1, ptr %out_cos, align 4
81+
ret void
82+
}
83+
84+
define void @test_sincos_nxv2f64(<vscale x 2 x double> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
85+
; SLEEF-LABEL: test_sincos_nxv2f64:
86+
; SLEEF: // %bb.0:
87+
; SLEEF-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
88+
; SLEEF-NEXT: .cfi_def_cfa_offset 16
89+
; SLEEF-NEXT: .cfi_offset w30, -16
90+
; SLEEF-NEXT: bl _ZGVsNxvl8l8_sincos
91+
; SLEEF-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
92+
; SLEEF-NEXT: ret
93+
;
94+
; ARMPL-LABEL: test_sincos_nxv2f64:
95+
; ARMPL: // %bb.0:
96+
; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
97+
; ARMPL-NEXT: .cfi_def_cfa_offset 16
98+
; ARMPL-NEXT: .cfi_offset w30, -16
99+
; ARMPL-NEXT: ptrue p0.d
100+
; ARMPL-NEXT: bl armpl_svsincos_f64_x
101+
; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
102+
; ARMPL-NEXT: ret
103+
%result = call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.sincos.nxv2f64(<vscale x 2 x double> %x)
104+
%result.0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %result, 0
105+
%result.1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %result, 1
106+
store <vscale x 2 x double> %result.0, ptr %out_sin, align 8
107+
store <vscale x 2 x double> %result.1, ptr %out_cos, align 8
108+
ret void
109+
}

0 commit comments

Comments
 (0)