@@ -8653,6 +8653,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
86538653 }
86548654}
86558655
8656+ static SMECallAttrs
8657+ getSMECallAttrs(const Function &Function,
8658+ const TargetLowering::CallLoweringInfo &CLI) {
8659+ if (CLI.CB)
8660+ return SMECallAttrs(*CLI.CB);
8661+ if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8662+ return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
8663+ return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
8664+ }
8665+
86568666bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86578667 const CallLoweringInfo &CLI) const {
86588668 CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8671,12 +8681,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86718681
86728682 // SME Streaming functions are not eligible for TCO as they may require
86738683 // the streaming mode or ZA to be restored after returning from the call.
8674- SMEAttrs CallerAttrs(MF.getFunction());
8675- auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8676- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8677- CallerAttrs.requiresLazySave(CalleeAttrs) ||
8678- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8679- CallerAttrs.hasStreamingBody())
8684+ SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8685+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8686+ CallAttrs.requiresPreservingAllZAState() ||
8687+ CallAttrs.caller().hasStreamingBody())
86808688 return false;
86818689
86828690 // Functions using the C or Fast calling convention that have an SVE signature
@@ -8968,14 +8976,13 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89688976 return TLI.LowerCallTo(CLI).second;
89698977}
89708978
8971- static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8972- const SMEAttrs &CalleeAttrs) {
8973- if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8974- CallerAttrs.hasStreamingBody())
8979+ static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
8980+ if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
8981+ CallAttrs.caller().hasStreamingBody())
89758982 return AArch64SME::Always;
8976- if (CalleeAttrs .hasNonStreamingInterface())
8983+ if (CallAttrs.calleeOrCallsite() .hasNonStreamingInterface())
89778984 return AArch64SME::IfCallerIsStreaming;
8978- if (CalleeAttrs .hasStreamingInterface())
8985+ if (CallAttrs.calleeOrCallsite() .hasStreamingInterface())
89798986 return AArch64SME::IfCallerIsNonStreaming;
89808987
89818988 llvm_unreachable("Unsupported attributes");
@@ -9108,11 +9115,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91089115 }
91099116
91109117 // Determine whether we need any streaming mode changes.
9111- SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9112- if (CLI.CB)
9113- CalleeAttrs = SMEAttrs(*CLI.CB);
9114- else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9115- CalleeAttrs = SMEAttrs(ES->getSymbol());
9118+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
91169119
91179120 auto DescribeCallsite =
91189121 [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9127,9 +9130,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91279130 return R;
91289131 };
91299132
9130- bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9131- bool RequiresSaveAllZA =
9132- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
9133+ bool RequiresLazySave = CallAttrs.requiresLazySave();
9134+ bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
91339135 if (RequiresLazySave) {
91349136 const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91359137 MachinePointerInfo MPI =
@@ -9157,18 +9159,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91579159 return DescribeCallsite(R) << " sets up a lazy save for ZA";
91589160 });
91599161 } else if (RequiresSaveAllZA) {
9160- assert(!CalleeAttrs .hasSharedZAInterface() &&
9162+ assert(!CallAttrs.calleeOrCallsite() .hasSharedZAInterface() &&
91619163 "Cannot share state that may not exist");
91629164 Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
91639165 /*IsSave=*/true);
91649166 }
91659167
91669168 SDValue PStateSM;
9167- bool RequiresSMChange = CallerAttrs .requiresSMChange(CalleeAttrs );
9169+ bool RequiresSMChange = CallAttrs .requiresSMChange();
91689170 if (RequiresSMChange) {
9169- if (CallerAttrs .hasStreamingInterfaceOrBody())
9171+ if (CallAttrs.caller() .hasStreamingInterfaceOrBody())
91709172 PStateSM = DAG.getConstant(1, DL, MVT::i64);
9171- else if (CallerAttrs .hasNonStreamingInterface())
9173+ else if (CallAttrs.caller() .hasNonStreamingInterface())
91729174 PStateSM = DAG.getConstant(0, DL, MVT::i64);
91739175 else
91749176 PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9185,7 +9187,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91859187
91869188 SDValue ZTFrameIdx;
91879189 MachineFrameInfo &MFI = MF.getFrameInfo();
9188- bool ShouldPreserveZT0 = CallerAttrs .requiresPreservingZT0(CalleeAttrs );
9190+ bool ShouldPreserveZT0 = CallAttrs .requiresPreservingZT0();
91899191
91909192 // If the caller has ZT0 state which will not be preserved by the callee,
91919193 // spill ZT0 before the call.
@@ -9201,7 +9203,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
92019203
92029204 // If caller shares ZT0 but the callee is not shared ZA, we need to stop
92039205 // PSTATE.ZA before the call if there is no lazy-save active.
9204- bool DisableZA = CallerAttrs .requiresDisablingZABeforeCall(CalleeAttrs );
9206+ bool DisableZA = CallAttrs .requiresDisablingZABeforeCall();
92059207 assert((!DisableZA || !RequiresLazySave) &&
92069208 "Lazy-save should have PSTATE.SM=1 on entry to the function");
92079209
@@ -9484,8 +9486,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94849486 }
94859487
94869488 SDValue NewChain = changeStreamingMode(
9487- DAG, DL, CalleeAttrs. hasStreamingInterface(), Chain, InGlue ,
9488- getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9489+ DAG, DL, CallAttrs.calleeOrCallsite(). hasStreamingInterface(), Chain,
9490+ InGlue, getSMCondition(CallAttrs ), PStateSM);
94899491 Chain = NewChain.getValue(0);
94909492 InGlue = NewChain.getValue(1);
94919493 }
@@ -9664,8 +9666,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96649666 if (RequiresSMChange) {
96659667 assert(PStateSM && "Expected a PStateSM to be set");
96669668 Result = changeStreamingMode(
9667- DAG, DL, !CalleeAttrs. hasStreamingInterface(), Result, InGlue ,
9668- getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9669+ DAG, DL, !CallAttrs.calleeOrCallsite(). hasStreamingInterface(), Result,
9670+ InGlue, getSMCondition(CallAttrs ), PStateSM);
96699671
96709672 if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96719673 InGlue = Result.getValue(1);
@@ -9675,7 +9677,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96759677 }
96769678 }
96779679
9678- if (CallerAttrs .requiresEnablingZAAfterCall(CalleeAttrs ))
9680+ if (CallAttrs .requiresEnablingZAAfterCall())
96799681 // Unconditionally resume ZA.
96809682 Result = DAG.getNode(
96819683 AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28552,12 +28554,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2855228554
2855328555 // Checks to allow the use of SME instructions
2855428556 if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28555- auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28556- auto CalleeAttrs = SMEAttrs(*Base);
28557- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28558- CallerAttrs.requiresLazySave(CalleeAttrs) ||
28559- CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28560- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
28557+ auto CallAttrs = SMECallAttrs(*Base);
28558+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28559+ CallAttrs.requiresPreservingZT0() ||
28560+ CallAttrs.requiresPreservingAllZAState())
2856128561 return true;
2856228562 }
2856328563 return false;
0 commit comments