Skip to content

Commit c021e16

Browse files
authored
[AArch64][SME] Handle SME state around TLS-descriptor calls (#155608)
This patch ensures we switch out of streaming mode before TLS-descriptor calls. ZA state will also be preserved when using the new SME ABI lowering (`-aarch64-new-sme-abi`). Fixes #152165
1 parent 40a9e34 commit c021e16

File tree

5 files changed

+205
-10
lines changed

5 files changed

+205
-10
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 30 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9602,8 +9602,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96029602
// using a chain can result in incorrect scheduling. The markers refer to
96039603
// the position just before the CALLSEQ_START (though occur after as
96049604
// CALLSEQ_START lacks in-glue).
9605-
Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
9606-
{Chain, Chain.getValue(1)});
9605+
Chain =
9606+
DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other, MVT::Glue),
9607+
{Chain, Chain.getValue(1)});
96079608
}
96089609
}
96099610

@@ -10608,16 +10609,41 @@ SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr,
1060810609
const SDLoc &DL,
1060910610
SelectionDAG &DAG) const {
1061010611
EVT PtrVT = getPointerTy(DAG.getDataLayout());
10612+
auto &MF = DAG.getMachineFunction();
10613+
auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
1061110614

10615+
SDValue Glue;
1061210616
SDValue Chain = DAG.getEntryNode();
1061310617
SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
1061410618

10619+
SMECallAttrs TLSCallAttrs(FuncInfo->getSMEFnAttrs(), {}, SMEAttrs::Normal);
10620+
bool RequiresSMChange = TLSCallAttrs.requiresSMChange();
10621+
10622+
auto ChainAndGlue = [](SDValue Chain) -> std::pair<SDValue, SDValue> {
10623+
return {Chain, Chain.getValue(1)};
10624+
};
10625+
10626+
if (RequiresSMChange)
10627+
std::tie(Chain, Glue) =
10628+
ChainAndGlue(changeStreamingMode(DAG, DL, /*Enable=*/false, Chain, Glue,
10629+
getSMToggleCondition(TLSCallAttrs)));
10630+
1061510631
unsigned Opcode =
1061610632
DAG.getMachineFunction().getInfo<AArch64FunctionInfo>()->hasELFSignedGOT()
1061710633
? AArch64ISD::TLSDESC_AUTH_CALLSEQ
1061810634
: AArch64ISD::TLSDESC_CALLSEQ;
10619-
Chain = DAG.getNode(Opcode, DL, NodeTys, {Chain, SymAddr});
10620-
SDValue Glue = Chain.getValue(1);
10635+
SDValue Ops[] = {Chain, SymAddr, Glue};
10636+
std::tie(Chain, Glue) = ChainAndGlue(DAG.getNode(
10637+
Opcode, DL, NodeTys, Glue ? ArrayRef(Ops) : ArrayRef(Ops).drop_back()));
10638+
10639+
if (TLSCallAttrs.requiresLazySave())
10640+
std::tie(Chain, Glue) = ChainAndGlue(DAG.getNode(
10641+
AArch64ISD::REQUIRES_ZA_SAVE, DL, NodeTys, {Chain, Chain.getValue(1)}));
10642+
10643+
if (RequiresSMChange)
10644+
std::tie(Chain, Glue) =
10645+
ChainAndGlue(changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
10646+
getSMToggleCondition(TLSCallAttrs)));
1062110647

1062210648
return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue);
1062310649
}

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1042,11 +1042,11 @@ def AArch64uitof: SDNode<"AArch64ISD::UITOF", SDT_AArch64ITOF>;
10421042
// offset of a variable into X0, using the TLSDesc model.
10431043
def AArch64tlsdesc_callseq : SDNode<"AArch64ISD::TLSDESC_CALLSEQ",
10441044
SDT_AArch64TLSDescCallSeq,
1045-
[SDNPOutGlue, SDNPHasChain, SDNPVariadic]>;
1045+
[SDNPOutGlue, SDNPOptInGlue, SDNPHasChain, SDNPVariadic]>;
10461046

10471047
def AArch64tlsdesc_auth_callseq : SDNode<"AArch64ISD::TLSDESC_AUTH_CALLSEQ",
10481048
SDT_AArch64TLSDescCallSeq,
1049-
[SDNPOutGlue, SDNPHasChain, SDNPVariadic]>;
1049+
[SDNPOutGlue, SDNPOptInGlue, SDNPHasChain, SDNPVariadic]>;
10501050

10511051
def AArch64WrapperLarge : SDNode<"AArch64ISD::WrapperLarge",
10521052
SDT_AArch64WrapperLarge>;

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -113,12 +113,12 @@ def CommitZASavePseudo
113113

114114
def AArch64_inout_za_use
115115
: SDNode<"AArch64ISD::INOUT_ZA_USE", SDTypeProfile<0, 0,[]>,
116-
[SDNPHasChain, SDNPInGlue]>;
116+
[SDNPHasChain, SDNPInGlue, SDNPOutGlue]>;
117117
def : Pat<(AArch64_inout_za_use), (InOutZAUsePseudo)>;
118118

119119
def AArch64_requires_za_save
120120
: SDNode<"AArch64ISD::REQUIRES_ZA_SAVE", SDTypeProfile<0, 0,[]>,
121-
[SDNPHasChain, SDNPInGlue]>;
121+
[SDNPHasChain, SDNPInGlue, SDNPOutGlue]>;
122122
def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
123123

124124
def AArch64_sme_state_alloc

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -381,6 +381,17 @@ static void setPhysLiveRegs(LiveRegUnits &LiveUnits, LiveRegs PhysLiveRegs) {
381381
LiveUnits.addReg(AArch64::W0_HI);
382382
}
383383

384+
[[maybe_unused]] bool isCallStartOpcode(unsigned Opc) {
385+
switch (Opc) {
386+
case AArch64::TLSDESC_CALLSEQ:
387+
case AArch64::TLSDESC_AUTH_CALLSEQ:
388+
case AArch64::ADJCALLSTACKDOWN:
389+
return true;
390+
default:
391+
return false;
392+
}
393+
}
394+
384395
FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
385396
assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
386397
SMEFnAttrs.hasZAState()) &&
@@ -424,8 +435,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
424435
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
425436
auto [NeededState, InsertPt] = getZAStateBeforeInst(
426437
*TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
427-
assert((InsertPt == MBBI ||
428-
InsertPt->getOpcode() == AArch64::ADJCALLSTACKDOWN) &&
438+
assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) &&
429439
"Unexpected state change insertion point!");
430440
// TODO: Do something to avoid state changes where NZCV is live.
431441
if (MBBI == FirstTerminatorInsertPt)
Lines changed: 159 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,159 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mtriple=aarch64-unknown-linux-gnu -mattr=+sme -aarch64-new-sme-abi -relocation-model=pic < %s | FileCheck %s
3+
4+
@x = external thread_local local_unnamed_addr global i32, align 4
5+
6+
define i32 @load_tls_streaming_compat() nounwind "aarch64_pstate_sm_compatible" {
7+
; CHECK-LABEL: load_tls_streaming_compat:
8+
; CHECK: // %bb.0: // %entry
9+
; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
10+
; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
11+
; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
12+
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
13+
; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
14+
; CHECK-NEXT: mrs x8, SVCR
15+
; CHECK-NEXT: tbz w8, #0, .LBB0_2
16+
; CHECK-NEXT: // %bb.1: // %entry
17+
; CHECK-NEXT: smstop sm
18+
; CHECK-NEXT: .LBB0_2: // %entry
19+
; CHECK-NEXT: adrp x0, :tlsdesc:x
20+
; CHECK-NEXT: ldr x1, [x0, :tlsdesc_lo12:x]
21+
; CHECK-NEXT: add x0, x0, :tlsdesc_lo12:x
22+
; CHECK-NEXT: .tlsdesccall x
23+
; CHECK-NEXT: blr x1
24+
; CHECK-NEXT: tbz w8, #0, .LBB0_4
25+
; CHECK-NEXT: // %bb.3: // %entry
26+
; CHECK-NEXT: smstart sm
27+
; CHECK-NEXT: .LBB0_4: // %entry
28+
; CHECK-NEXT: mrs x8, TPIDR_EL0
29+
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
30+
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
31+
; CHECK-NEXT: ldr w0, [x8, x0]
32+
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
33+
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
34+
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
35+
; CHECK-NEXT: ret
36+
entry:
37+
%0 = tail call align 4 ptr @llvm.threadlocal.address.p0(ptr align 4 @x)
38+
%1 = load i32, ptr %0, align 4
39+
ret i32 %1
40+
}
41+
42+
define i32 @load_tls_streaming() nounwind "aarch64_pstate_sm_enabled" {
43+
; CHECK-LABEL: load_tls_streaming:
44+
; CHECK: // %bb.0: // %entry
45+
; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
46+
; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
47+
; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
48+
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
49+
; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
50+
; CHECK-NEXT: smstop sm
51+
; CHECK-NEXT: adrp x0, :tlsdesc:x
52+
; CHECK-NEXT: ldr x1, [x0, :tlsdesc_lo12:x]
53+
; CHECK-NEXT: add x0, x0, :tlsdesc_lo12:x
54+
; CHECK-NEXT: .tlsdesccall x
55+
; CHECK-NEXT: blr x1
56+
; CHECK-NEXT: smstart sm
57+
; CHECK-NEXT: mrs x8, TPIDR_EL0
58+
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
59+
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
60+
; CHECK-NEXT: ldr w0, [x8, x0]
61+
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
62+
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
63+
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
64+
; CHECK-NEXT: ret
65+
entry:
66+
%0 = tail call align 4 ptr @llvm.threadlocal.address.p0(ptr align 4 @x)
67+
%1 = load i32, ptr %0, align 4
68+
ret i32 %1
69+
}
70+
71+
define i32 @load_tls_shared_za() nounwind "aarch64_inout_za" {
72+
; CHECK-LABEL: load_tls_shared_za:
73+
; CHECK: // %bb.0: // %entry
74+
; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
75+
; CHECK-NEXT: mov x29, sp
76+
; CHECK-NEXT: sub sp, sp, #16
77+
; CHECK-NEXT: rdsvl x8, #1
78+
; CHECK-NEXT: mov x9, sp
79+
; CHECK-NEXT: msub x9, x8, x8, x9
80+
; CHECK-NEXT: mov sp, x9
81+
; CHECK-NEXT: sub x10, x29, #16
82+
; CHECK-NEXT: stp x9, x8, [x29, #-16]
83+
; CHECK-NEXT: msr TPIDR2_EL0, x10
84+
; CHECK-NEXT: adrp x0, :tlsdesc:x
85+
; CHECK-NEXT: ldr x1, [x0, :tlsdesc_lo12:x]
86+
; CHECK-NEXT: add x0, x0, :tlsdesc_lo12:x
87+
; CHECK-NEXT: .tlsdesccall x
88+
; CHECK-NEXT: blr x1
89+
; CHECK-NEXT: mrs x8, TPIDR_EL0
90+
; CHECK-NEXT: ldr w0, [x8, x0]
91+
; CHECK-NEXT: mov w8, w0
92+
; CHECK-NEXT: smstart za
93+
; CHECK-NEXT: mrs x9, TPIDR2_EL0
94+
; CHECK-NEXT: sub x0, x29, #16
95+
; CHECK-NEXT: cbnz x9, .LBB2_2
96+
; CHECK-NEXT: // %bb.1: // %entry
97+
; CHECK-NEXT: bl __arm_tpidr2_restore
98+
; CHECK-NEXT: .LBB2_2: // %entry
99+
; CHECK-NEXT: mov w0, w8
100+
; CHECK-NEXT: msr TPIDR2_EL0, xzr
101+
; CHECK-NEXT: mov sp, x29
102+
; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
103+
; CHECK-NEXT: ret
104+
entry:
105+
%0 = tail call align 4 ptr @llvm.threadlocal.address.p0(ptr align 4 @x)
106+
%1 = load i32, ptr %0, align 4
107+
ret i32 %1
108+
}
109+
110+
define i32 @load_tls_streaming_shared_za() nounwind "aarch64_inout_za" "aarch64_pstate_sm_enabled" {
111+
; CHECK-LABEL: load_tls_streaming_shared_za:
112+
; CHECK: // %bb.0: // %entry
113+
; CHECK-NEXT: stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
114+
; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
115+
; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
116+
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
117+
; CHECK-NEXT: stp x29, x30, [sp, #64] // 16-byte Folded Spill
118+
; CHECK-NEXT: add x29, sp, #64
119+
; CHECK-NEXT: str x19, [sp, #80] // 8-byte Folded Spill
120+
; CHECK-NEXT: sub sp, sp, #16
121+
; CHECK-NEXT: rdsvl x8, #1
122+
; CHECK-NEXT: mov x9, sp
123+
; CHECK-NEXT: msub x9, x8, x8, x9
124+
; CHECK-NEXT: mov sp, x9
125+
; CHECK-NEXT: stp x9, x8, [x29, #-80]
126+
; CHECK-NEXT: smstop sm
127+
; CHECK-NEXT: sub x8, x29, #80
128+
; CHECK-NEXT: msr TPIDR2_EL0, x8
129+
; CHECK-NEXT: adrp x0, :tlsdesc:x
130+
; CHECK-NEXT: ldr x1, [x0, :tlsdesc_lo12:x]
131+
; CHECK-NEXT: add x0, x0, :tlsdesc_lo12:x
132+
; CHECK-NEXT: .tlsdesccall x
133+
; CHECK-NEXT: blr x1
134+
; CHECK-NEXT: smstart sm
135+
; CHECK-NEXT: mrs x8, TPIDR_EL0
136+
; CHECK-NEXT: ldr w0, [x8, x0]
137+
; CHECK-NEXT: mov w8, w0
138+
; CHECK-NEXT: smstart za
139+
; CHECK-NEXT: mrs x9, TPIDR2_EL0
140+
; CHECK-NEXT: sub x0, x29, #80
141+
; CHECK-NEXT: cbnz x9, .LBB3_2
142+
; CHECK-NEXT: // %bb.1: // %entry
143+
; CHECK-NEXT: bl __arm_tpidr2_restore
144+
; CHECK-NEXT: .LBB3_2: // %entry
145+
; CHECK-NEXT: mov w0, w8
146+
; CHECK-NEXT: msr TPIDR2_EL0, xzr
147+
; CHECK-NEXT: sub sp, x29, #64
148+
; CHECK-NEXT: ldp x29, x30, [sp, #64] // 16-byte Folded Reload
149+
; CHECK-NEXT: ldr x19, [sp, #80] // 8-byte Folded Reload
150+
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
151+
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
152+
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
153+
; CHECK-NEXT: ldp d15, d14, [sp], #96 // 16-byte Folded Reload
154+
; CHECK-NEXT: ret
155+
entry:
156+
%0 = tail call align 4 ptr @llvm.threadlocal.address.p0(ptr align 4 @x)
157+
%1 = load i32, ptr %0, align 4
158+
ret i32 %1
159+
}

0 commit comments

Comments
 (0)