Skip to content

Commit 0b9495b

Browse files
committed
[AArch64][SME] Split SMECallAttrs out of SMEAttrs (NFC)
SMECallAttrs is a new helper class that holds all the SMEAttrs for a call: those of the caller, the callee, and the callsite. The interfaces that query which actions a call needs (e.g. changing the streaming mode) have been moved to the SMECallAttrs class. The main motivation for this change is to make the split between caller, callee, and callsite attributes more apparent. Places that previously checked callsite attributes implicitly have been updated to make these checks explicit. Similarly, places known to check only callee or callsite attributes have been updated to make this clear.
1 parent: 8c7a2ce · commit: 0b9495b

File tree

5 files changed

+206
-161
lines changed

5 files changed

+206
-161
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 37 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -8653,6 +8653,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
86538653
}
86548654
}
86558655

8656+
static SMECallAttrs
8657+
getSMECallAttrs(const Function &Function,
8658+
const TargetLowering::CallLoweringInfo &CLI) {
8659+
if (CLI.CB)
8660+
return SMECallAttrs(*CLI.CB);
8661+
if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8662+
return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
8663+
return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
8664+
}
8665+
86568666
bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86578667
const CallLoweringInfo &CLI) const {
86588668
CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8671,12 +8681,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86718681

86728682
// SME Streaming functions are not eligible for TCO as they may require
86738683
// the streaming mode or ZA to be restored after returning from the call.
8674-
SMEAttrs CallerAttrs(MF.getFunction());
8675-
auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8676-
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8677-
CallerAttrs.requiresLazySave(CalleeAttrs) ||
8678-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8679-
CallerAttrs.hasStreamingBody())
8684+
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8685+
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8686+
CallAttrs.requiresPreservingAllZAState() ||
8687+
CallAttrs.caller().hasStreamingBody())
86808688
return false;
86818689

86828690
// Functions using the C or Fast calling convention that have an SVE signature
@@ -8968,14 +8976,13 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89688976
return TLI.LowerCallTo(CLI).second;
89698977
}
89708978

8971-
static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8972-
const SMEAttrs &CalleeAttrs) {
8973-
if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8974-
CallerAttrs.hasStreamingBody())
8979+
static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
8980+
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
8981+
CallAttrs.caller().hasStreamingBody())
89758982
return AArch64SME::Always;
8976-
if (CalleeAttrs.hasNonStreamingInterface())
8983+
if (CallAttrs.calleeOrCallsite().hasNonStreamingInterface())
89778984
return AArch64SME::IfCallerIsStreaming;
8978-
if (CalleeAttrs.hasStreamingInterface())
8985+
if (CallAttrs.calleeOrCallsite().hasStreamingInterface())
89798986
return AArch64SME::IfCallerIsNonStreaming;
89808987

89818988
llvm_unreachable("Unsupported attributes");
@@ -9108,11 +9115,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91089115
}
91099116

91109117
// Determine whether we need any streaming mode changes.
9111-
SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9112-
if (CLI.CB)
9113-
CalleeAttrs = SMEAttrs(*CLI.CB);
9114-
else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9115-
CalleeAttrs = SMEAttrs(ES->getSymbol());
9118+
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
91169119

91179120
auto DescribeCallsite =
91189121
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9127,9 +9130,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91279130
return R;
91289131
};
91299132

9130-
bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9131-
bool RequiresSaveAllZA =
9132-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
9133+
bool RequiresLazySave = CallAttrs.requiresLazySave();
9134+
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
91339135
if (RequiresLazySave) {
91349136
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91359137
MachinePointerInfo MPI =
@@ -9157,18 +9159,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91579159
return DescribeCallsite(R) << " sets up a lazy save for ZA";
91589160
});
91599161
} else if (RequiresSaveAllZA) {
9160-
assert(!CalleeAttrs.hasSharedZAInterface() &&
9162+
assert(!CallAttrs.calleeOrCallsite().hasSharedZAInterface() &&
91619163
"Cannot share state that may not exist");
91629164
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
91639165
/*IsSave=*/true);
91649166
}
91659167

91669168
SDValue PStateSM;
9167-
bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
9169+
bool RequiresSMChange = CallAttrs.requiresSMChange();
91689170
if (RequiresSMChange) {
9169-
if (CallerAttrs.hasStreamingInterfaceOrBody())
9171+
if (CallAttrs.caller().hasStreamingInterfaceOrBody())
91709172
PStateSM = DAG.getConstant(1, DL, MVT::i64);
9171-
else if (CallerAttrs.hasNonStreamingInterface())
9173+
else if (CallAttrs.caller().hasNonStreamingInterface())
91729174
PStateSM = DAG.getConstant(0, DL, MVT::i64);
91739175
else
91749176
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9185,7 +9187,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91859187

91869188
SDValue ZTFrameIdx;
91879189
MachineFrameInfo &MFI = MF.getFrameInfo();
9188-
bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
9190+
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
91899191

91909192
// If the caller has ZT0 state which will not be preserved by the callee,
91919193
// spill ZT0 before the call.
@@ -9201,7 +9203,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
92019203

92029204
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
92039205
// PSTATE.ZA before the call if there is no lazy-save active.
9204-
bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
9206+
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
92059207
assert((!DisableZA || !RequiresLazySave) &&
92069208
"Lazy-save should have PSTATE.SM=1 on entry to the function");
92079209

@@ -9484,8 +9486,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94849486
}
94859487

94869488
SDValue NewChain = changeStreamingMode(
9487-
DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
9488-
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
9489+
DAG, DL, CallAttrs.calleeOrCallsite().hasStreamingInterface(), Chain,
9490+
InGlue, getSMCondition(CallAttrs), PStateSM);
94899491
Chain = NewChain.getValue(0);
94909492
InGlue = NewChain.getValue(1);
94919493
}
@@ -9664,8 +9666,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96649666
if (RequiresSMChange) {
96659667
assert(PStateSM && "Expected a PStateSM to be set");
96669668
Result = changeStreamingMode(
9667-
DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
9668-
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
9669+
DAG, DL, !CallAttrs.calleeOrCallsite().hasStreamingInterface(), Result,
9670+
InGlue, getSMCondition(CallAttrs), PStateSM);
96699671

96709672
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96719673
InGlue = Result.getValue(1);
@@ -9675,7 +9677,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96759677
}
96769678
}
96779679

9678-
if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
9680+
if (CallAttrs.requiresEnablingZAAfterCall())
96799681
// Unconditionally resume ZA.
96809682
Result = DAG.getNode(
96819683
AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28552,12 +28554,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2855228554

2855328555
// Checks to allow the use of SME instructions
2855428556
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28555-
auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28556-
auto CalleeAttrs = SMEAttrs(*Base);
28557-
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28558-
CallerAttrs.requiresLazySave(CalleeAttrs) ||
28559-
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28560-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
28557+
auto CallAttrs = SMECallAttrs(*Base);
28558+
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28559+
CallAttrs.requiresPreservingZT0() ||
28560+
CallAttrs.requiresPreservingAllZAState())
2856128561
return true;
2856228562
}
2856328563
return false;

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -268,22 +268,21 @@ const FeatureBitset AArch64TTIImpl::InlineInverseFeatures = {
268268

269269
bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
270270
const Function *Callee) const {
271-
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
271+
SMECallAttrs CallAttrs(*Caller, *Callee);
272272

273273
// When inlining, we should consider the body of the function, not the
274274
// interface.
275-
if (CalleeAttrs.hasStreamingBody()) {
276-
CalleeAttrs.set(SMEAttrs::SM_Compatible, false);
277-
CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
275+
if (CallAttrs.callee().hasStreamingBody()) {
276+
CallAttrs.callee().set(SMEAttrs::SM_Compatible, false);
277+
CallAttrs.callee().set(SMEAttrs::SM_Enabled, true);
278278
}
279279

280-
if (CalleeAttrs.isNewZA() || CalleeAttrs.isNewZT0())
280+
if (CallAttrs.callee().isNewZA() || CallAttrs.callee().isNewZT0())
281281
return false;
282282

283-
if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
284-
CallerAttrs.requiresSMChange(CalleeAttrs) ||
285-
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
286-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
283+
if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
284+
CallAttrs.requiresPreservingZT0() ||
285+
CallAttrs.requiresPreservingAllZAState()) {
287286
if (hasPossibleIncompatibleOps(Callee))
288287
return false;
289288
}
@@ -349,12 +348,14 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
349348
// streaming-mode change, and the call to G from F would also require a
350349
// streaming-mode change, then there is benefit to do the streaming-mode
351350
// change only once and avoid inlining of G into F.
351+
352352
SMEAttrs FAttrs(*F);
353-
SMEAttrs CalleeAttrs(Call);
354-
if (FAttrs.requiresSMChange(CalleeAttrs)) {
353+
SMECallAttrs CallAttrs(Call);
354+
355+
if (SMECallAttrs(FAttrs, CallAttrs.calleeOrCallsite()).requiresSMChange()) {
355356
if (F == Call.getCaller()) // (1)
356357
return CallPenaltyChangeSM * DefaultCallPenalty;
357-
if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
358+
if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
358359
return InlineCallPenaltyChangeSM * DefaultCallPenalty;
359360
}
360361

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp

Lines changed: 29 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -27,15 +27,14 @@ void SMEAttrs::set(unsigned M, bool Enable) {
2727
"ZA_New and SME_ABI_Routine are mutually exclusive");
2828

2929
assert(
30-
(!sharesZA() ||
31-
(isNewZA() ^ isInZA() ^ isInOutZA() ^ isOutZA() ^ isPreservesZA())) &&
30+
(isNewZA() + isInZA() + isOutZA() + isInOutZA() + isPreservesZA()) <= 1 &&
3231
"Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
3332
"'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive");
3433

3534
// ZT0 Attrs
3635
assert(
37-
(!sharesZT0() || (isNewZT0() ^ isInZT0() ^ isInOutZT0() ^ isOutZT0() ^
38-
isPreservesZT0())) &&
36+
(isNewZT0() + isInZT0() + isOutZT0() + isInOutZT0() + isPreservesZT0()) <=
37+
1 &&
3938
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
4039
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
4140

@@ -44,27 +43,6 @@ void SMEAttrs::set(unsigned M, bool Enable) {
4443
"interface");
4544
}
4645

47-
SMEAttrs::SMEAttrs(const CallBase &CB) {
48-
*this = SMEAttrs(CB.getAttributes());
49-
if (auto *F = CB.getCalledFunction()) {
50-
set(SMEAttrs(*F).Bitmask | SMEAttrs(F->getName()).Bitmask);
51-
}
52-
}
53-
54-
SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
55-
if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
56-
Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
57-
if (FuncName == "__arm_tpidr2_restore")
58-
Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
59-
SMEAttrs::SME_ABI_Routine;
60-
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
61-
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
62-
Bitmask |= SMEAttrs::SM_Compatible;
63-
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
64-
FuncName == "__arm_sme_state_size")
65-
Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
66-
}
67-
6846
SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
6947
Bitmask = 0;
7048
if (Attrs.hasFnAttr("aarch64_pstate_sm_enabled"))
@@ -99,17 +77,39 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
9977
Bitmask |= encodeZT0State(StateValue::New);
10078
}
10179

102-
bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
103-
if (Callee.hasStreamingCompatibleInterface())
80+
void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
81+
unsigned KnownAttrs = SMEAttrs::Normal;
82+
if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
83+
KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
84+
if (FuncName == "__arm_tpidr2_restore")
85+
KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
86+
SMEAttrs::SME_ABI_Routine;
87+
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
88+
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
89+
KnownAttrs |= SMEAttrs::SM_Compatible;
90+
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
91+
FuncName == "__arm_sme_state_size")
92+
KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
93+
set(KnownAttrs, /*Enable=*/true);
94+
}
95+
96+
bool SMECallAttrs::requiresSMChange() const {
97+
if ((Callsite | Callee).hasStreamingCompatibleInterface())
10498
return false;
10599

106100
// Both non-streaming
107-
if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface())
101+
if (Caller.hasNonStreamingInterfaceAndBody() &&
102+
(Callsite | Callee).hasNonStreamingInterface())
108103
return false;
109104

110105
// Both streaming
111-
if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface())
106+
if (Caller.hasStreamingInterfaceOrBody() &&
107+
(Callsite | Callee).hasStreamingInterface())
112108
return false;
113109

114110
return true;
115111
}
112+
113+
SMECallAttrs::SMECallAttrs(const CallBase &CB)
114+
: SMECallAttrs(*CB.getFunction(), CB.getCalledFunction(),
115+
CB.getAttributes()) {}

0 commit comments

Comments
 (0)