@@ -8636,6 +8636,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
86368636 }
86378637}
86388638
8639+ static SMECallAttrs
8640+ getSMECallAttrs(const Function &Function,
8641+ const TargetLowering::CallLoweringInfo &CLI) {
8642+ if (CLI.CB)
8643+ return SMECallAttrs(*CLI.CB);
8644+ if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8645+ return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
8646+ return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
8647+ }
8648+
86398649bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86408650 const CallLoweringInfo &CLI) const {
86418651 CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8654,12 +8664,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86548664
86558665 // SME Streaming functions are not eligible for TCO as they may require
86568666 // the streaming mode or ZA to be restored after returning from the call.
8657- SMEAttrs CallerAttrs(MF.getFunction());
8658- auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8659- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8660- CallerAttrs.requiresLazySave(CalleeAttrs) ||
8661- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8662- CallerAttrs.hasStreamingBody())
8667+ SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8668+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8669+ CallAttrs.requiresPreservingAllZAState() ||
8670+ CallAttrs.caller().hasStreamingBody())
86638671 return false;
86648672
86658673 // Functions using the C or Fast calling convention that have an SVE signature
@@ -8951,14 +8959,13 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89518959 return TLI.LowerCallTo(CLI).second;
89528960}
89538961
8954- static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8955- const SMEAttrs &CalleeAttrs) {
8956- if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8957- CallerAttrs.hasStreamingBody())
8962+ static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
8963+ if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
8964+ CallAttrs.caller().hasStreamingBody())
89588965 return AArch64SME::Always;
8959- if (CalleeAttrs .hasNonStreamingInterface())
8966+ if (CallAttrs.callee() .hasNonStreamingInterface())
89608967 return AArch64SME::IfCallerIsStreaming;
8961- if (CalleeAttrs .hasStreamingInterface())
8968+ if (CallAttrs.callee() .hasStreamingInterface())
89628969 return AArch64SME::IfCallerIsNonStreaming;
89638970
89648971 llvm_unreachable("Unsupported attributes");
@@ -9091,11 +9098,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90919098 }
90929099
90939100 // Determine whether we need any streaming mode changes.
9094- SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9095- if (CLI.CB)
9096- CalleeAttrs = SMEAttrs(*CLI.CB);
9097- else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9098- CalleeAttrs = SMEAttrs(ES->getSymbol());
9101+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
90999102
91009103 auto DescribeCallsite =
91019104 [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9110,9 +9113,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91109113 return R;
91119114 };
91129115
9113- bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9114- bool RequiresSaveAllZA =
9115- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
9116+ bool RequiresLazySave = CallAttrs.requiresLazySave();
9117+ bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
91169118 if (RequiresLazySave) {
91179119 const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91189120 MachinePointerInfo MPI =
@@ -9140,18 +9142,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91409142 return DescribeCallsite(R) << " sets up a lazy save for ZA";
91419143 });
91429144 } else if (RequiresSaveAllZA) {
9143- assert(!CalleeAttrs .hasSharedZAInterface() &&
9145+ assert(!CallAttrs.callee() .hasSharedZAInterface() &&
91449146 "Cannot share state that may not exist");
91459147 Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
91469148 /*IsSave=*/true);
91479149 }
91489150
91499151 SDValue PStateSM;
9150- bool RequiresSMChange = CallerAttrs .requiresSMChange(CalleeAttrs );
9152+ bool RequiresSMChange = CallAttrs .requiresSMChange();
91519153 if (RequiresSMChange) {
9152- if (CallerAttrs .hasStreamingInterfaceOrBody())
9154+ if (CallAttrs.caller() .hasStreamingInterfaceOrBody())
91539155 PStateSM = DAG.getConstant(1, DL, MVT::i64);
9154- else if (CallerAttrs .hasNonStreamingInterface())
9156+ else if (CallAttrs.caller() .hasNonStreamingInterface())
91559157 PStateSM = DAG.getConstant(0, DL, MVT::i64);
91569158 else
91579159 PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9168,7 +9170,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91689170
91699171 SDValue ZTFrameIdx;
91709172 MachineFrameInfo &MFI = MF.getFrameInfo();
9171- bool ShouldPreserveZT0 = CallerAttrs .requiresPreservingZT0(CalleeAttrs );
9173+ bool ShouldPreserveZT0 = CallAttrs .requiresPreservingZT0();
91729174
91739175 // If the caller has ZT0 state which will not be preserved by the callee,
91749176 // spill ZT0 before the call.
@@ -9184,7 +9186,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91849186
91859187 // If caller shares ZT0 but the callee is not shared ZA, we need to stop
91869188 // PSTATE.ZA before the call if there is no lazy-save active.
9187- bool DisableZA = CallerAttrs .requiresDisablingZABeforeCall(CalleeAttrs );
9189+ bool DisableZA = CallAttrs .requiresDisablingZABeforeCall();
91889190 assert((!DisableZA || !RequiresLazySave) &&
91899191 "Lazy-save should have PSTATE.SM=1 on entry to the function");
91909192
@@ -9466,9 +9468,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94669468 InGlue = Chain.getValue(1);
94679469 }
94689470
9469- SDValue NewChain = changeStreamingMode(
9470- DAG, DL, CalleeAttrs. hasStreamingInterface(), Chain, InGlue ,
9471- getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9471+ SDValue NewChain =
9472+ changeStreamingMode( DAG, DL, CallAttrs.callee(). hasStreamingInterface(),
9473+ Chain, InGlue, getSMCondition(CallAttrs ), PStateSM);
94729474 Chain = NewChain.getValue(0);
94739475 InGlue = NewChain.getValue(1);
94749476 }
@@ -9647,8 +9649,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96479649 if (RequiresSMChange) {
96489650 assert(PStateSM && "Expected a PStateSM to be set");
96499651 Result = changeStreamingMode(
9650- DAG, DL, !CalleeAttrs .hasStreamingInterface(), Result, InGlue,
9651- getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9652+ DAG, DL, !CallAttrs.callee() .hasStreamingInterface(), Result, InGlue,
9653+ getSMCondition(CallAttrs ), PStateSM);
96529654
96539655 if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96549656 InGlue = Result.getValue(1);
@@ -9658,7 +9660,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96589660 }
96599661 }
96609662
9661- if (CallerAttrs .requiresEnablingZAAfterCall(CalleeAttrs ))
9663+ if (CallAttrs .requiresEnablingZAAfterCall())
96629664 // Unconditionally resume ZA.
96639665 Result = DAG.getNode(
96649666 AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28518,12 +28520,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2851828520
2851928521 // Checks to allow the use of SME instructions
2852028522 if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28521- auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28522- auto CalleeAttrs = SMEAttrs(*Base);
28523- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28524- CallerAttrs.requiresLazySave(CalleeAttrs) ||
28525- CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28526- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
28523+ auto CallAttrs = SMECallAttrs(*Base);
28524+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28525+ CallAttrs.requiresPreservingZT0() ||
28526+ CallAttrs.requiresPreservingAllZAState())
2852728527 return true;
2852828528 }
2852928529 return false;
0 commit comments