@@ -8641,6 +8641,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
86418641 }
86428642}
86438643
8644+ static SMECallAttrs
8645+ getSMECallAttrs(const Function &Caller,
8646+ const TargetLowering::CallLoweringInfo &CLI) {
8647+ if (CLI.CB)
8648+ return SMECallAttrs(*CLI.CB);
8649+ if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8650+ return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol()));
8651+ return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal));
8652+ }
8653+
86448654bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86458655 const CallLoweringInfo &CLI) const {
86468656 CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8659,12 +8669,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86598669
86608670 // SME Streaming functions are not eligible for TCO as they may require
86618671 // the streaming mode or ZA to be restored after returning from the call.
8662- SMEAttrs CallerAttrs(MF.getFunction());
8663- auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8664- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8665- CallerAttrs.requiresLazySave(CalleeAttrs) ||
8666- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8667- CallerAttrs.hasStreamingBody())
8672+ SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8673+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8674+ CallAttrs.requiresPreservingAllZAState() ||
8675+ CallAttrs.caller().hasStreamingBody())
86688676 return false;
86698677
86708678 // Functions using the C or Fast calling convention that have an SVE signature
@@ -8956,14 +8964,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89568964 return TLI.LowerCallTo(CLI).second;
89578965}
89588966
8959- static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8960- const SMEAttrs &CalleeAttrs ) {
8961- if (!CallerAttrs .hasStreamingCompatibleInterface() ||
8962- CallerAttrs .hasStreamingBody())
8967+ static AArch64SME::ToggleCondition
8968+ getSMToggleCondition( const SMECallAttrs &CallAttrs ) {
8969+ if (!CallAttrs.caller() .hasStreamingCompatibleInterface() ||
8970+ CallAttrs.caller() .hasStreamingBody())
89638971 return AArch64SME::Always;
8964- if (CalleeAttrs .hasNonStreamingInterface())
8972+ if (CallAttrs.callee() .hasNonStreamingInterface())
89658973 return AArch64SME::IfCallerIsStreaming;
8966- if (CalleeAttrs .hasStreamingInterface())
8974+ if (CallAttrs.callee() .hasStreamingInterface())
89678975 return AArch64SME::IfCallerIsNonStreaming;
89688976
89698977 llvm_unreachable("Unsupported attributes");
@@ -9096,11 +9104,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90969104 }
90979105
90989106 // Determine whether we need any streaming mode changes.
9099- SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9100- if (CLI.CB)
9101- CalleeAttrs = SMEAttrs(*CLI.CB);
9102- else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9103- CalleeAttrs = SMEAttrs(ES->getSymbol());
9107+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
91049108
91059109 auto DescribeCallsite =
91069110 [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9115,9 +9119,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91159119 return R;
91169120 };
91179121
9118- bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9119- bool RequiresSaveAllZA =
9120- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
9122+ bool RequiresLazySave = CallAttrs.requiresLazySave();
9123+ bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
91219124 if (RequiresLazySave) {
91229125 const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91239126 MachinePointerInfo MPI =
@@ -9145,18 +9148,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91459148 return DescribeCallsite(R) << " sets up a lazy save for ZA";
91469149 });
91479150 } else if (RequiresSaveAllZA) {
9148- assert(!CalleeAttrs .hasSharedZAInterface() &&
9151+ assert(!CallAttrs.callee() .hasSharedZAInterface() &&
91499152 "Cannot share state that may not exist");
91509153 Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
91519154 /*IsSave=*/true);
91529155 }
91539156
91549157 SDValue PStateSM;
9155- bool RequiresSMChange = CallerAttrs .requiresSMChange(CalleeAttrs );
9158+ bool RequiresSMChange = CallAttrs .requiresSMChange();
91569159 if (RequiresSMChange) {
9157- if (CallerAttrs .hasStreamingInterfaceOrBody())
9160+ if (CallAttrs.caller() .hasStreamingInterfaceOrBody())
91589161 PStateSM = DAG.getConstant(1, DL, MVT::i64);
9159- else if (CallerAttrs .hasNonStreamingInterface())
9162+ else if (CallAttrs.caller() .hasNonStreamingInterface())
91609163 PStateSM = DAG.getConstant(0, DL, MVT::i64);
91619164 else
91629165 PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9173,7 +9176,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91739176
91749177 SDValue ZTFrameIdx;
91759178 MachineFrameInfo &MFI = MF.getFrameInfo();
9176- bool ShouldPreserveZT0 = CallerAttrs .requiresPreservingZT0(CalleeAttrs );
9179+ bool ShouldPreserveZT0 = CallAttrs .requiresPreservingZT0();
91779180
91789181 // If the caller has ZT0 state which will not be preserved by the callee,
91799182 // spill ZT0 before the call.
@@ -9189,7 +9192,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91899192
91909193 // If caller shares ZT0 but the callee is not shared ZA, we need to stop
91919194 // PSTATE.ZA before the call if there is no lazy-save active.
9192- bool DisableZA = CallerAttrs .requiresDisablingZABeforeCall(CalleeAttrs );
9195+ bool DisableZA = CallAttrs .requiresDisablingZABeforeCall();
91939196 assert((!DisableZA || !RequiresLazySave) &&
91949197 "Lazy-save should have PSTATE.SM=1 on entry to the function");
91959198
@@ -9472,8 +9475,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94729475 }
94739476
94749477 SDValue NewChain = changeStreamingMode(
9475- DAG, DL, CalleeAttrs .hasStreamingInterface(), Chain, InGlue,
9476- getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9478+ DAG, DL, CallAttrs.callee() .hasStreamingInterface(), Chain, InGlue,
9479+ getSMToggleCondition(CallAttrs ), PStateSM);
94779480 Chain = NewChain.getValue(0);
94789481 InGlue = NewChain.getValue(1);
94799482 }
@@ -9659,8 +9662,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96599662 if (RequiresSMChange) {
96609663 assert(PStateSM && "Expected a PStateSM to be set");
96619664 Result = changeStreamingMode(
9662- DAG, DL, !CalleeAttrs .hasStreamingInterface(), Result, InGlue,
9663- getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9665+ DAG, DL, !CallAttrs.callee() .hasStreamingInterface(), Result, InGlue,
9666+ getSMToggleCondition(CallAttrs ), PStateSM);
96649667
96659668 if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96669669 InGlue = Result.getValue(1);
@@ -9670,7 +9673,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96709673 }
96719674 }
96729675
9673- if (CallerAttrs .requiresEnablingZAAfterCall(CalleeAttrs ))
9676+ if (CallAttrs .requiresEnablingZAAfterCall())
96749677 // Unconditionally resume ZA.
96759678 Result = DAG.getNode(
96769679 AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28559,12 +28562,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2855928562
2856028563 // Checks to allow the use of SME instructions
2856128564 if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28562- auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28563- auto CalleeAttrs = SMEAttrs(*Base);
28564- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28565- CallerAttrs.requiresLazySave(CalleeAttrs) ||
28566- CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28567- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
28565+ auto CallAttrs = SMECallAttrs(*Base);
28566+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28567+ CallAttrs.requiresPreservingZT0() ||
28568+ CallAttrs.requiresPreservingAllZAState())
2856828569 return true;
2856928570 }
2857028571 return false;
0 commit comments