@@ -8636,16 +8636,6 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
86368636 }
86378637}
86388638
8639- static SMECallAttrs
8640- getSMECallAttrs(const Function &Function,
8641- const TargetLowering::CallLoweringInfo &CLI) {
8642- if (CLI.CB)
8643- return SMECallAttrs(*CLI.CB);
8644- if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8645- return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
8646- return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
8647- }
8648-
86498639bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86508640 const CallLoweringInfo &CLI) const {
86518641 CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8664,10 +8654,12 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86648654
86658655 // SME Streaming functions are not eligible for TCO as they may require
86668656 // the streaming mode or ZA to be restored after returning from the call.
8667- SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8668- if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8669- CallAttrs.requiresPreservingAllZAState() ||
8670- CallAttrs.caller().hasStreamingBody())
8657+ SMEAttrs CallerAttrs(MF.getFunction());
8658+ auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8659+ if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8660+ CallerAttrs.requiresLazySave(CalleeAttrs) ||
8661+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8662+ CallerAttrs.hasStreamingBody())
86718663 return false;
86728664
86738665 // Functions using the C or Fast calling convention that have an SVE signature
@@ -8959,13 +8951,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89598951 return TLI.LowerCallTo(CLI).second;
89608952}
89618953
8962- static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
8963- if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
8964- CallAttrs.caller().hasStreamingBody())
8954+ static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8955+ const SMEAttrs &CalleeAttrs) {
8956+ if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8957+ CallerAttrs.hasStreamingBody())
89658958 return AArch64SME::Always;
8966- if (CallAttrs.callee() .hasNonStreamingInterface())
8959+ if (CalleeAttrs .hasNonStreamingInterface())
89678960 return AArch64SME::IfCallerIsStreaming;
8968- if (CallAttrs.callee() .hasStreamingInterface())
8961+ if (CalleeAttrs .hasStreamingInterface())
89698962 return AArch64SME::IfCallerIsNonStreaming;
89708963
89718964 llvm_unreachable("Unsupported attributes");
@@ -9098,7 +9091,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90989091 }
90999092
91009093 // Determine whether we need any streaming mode changes.
9101- SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
9094+ SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9095+ if (CLI.CB)
9096+ CalleeAttrs = SMEAttrs(*CLI.CB);
9097+ else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9098+ CalleeAttrs = SMEAttrs(ES->getSymbol());
91029099
91039100 auto DescribeCallsite =
91049101 [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9113,8 +9110,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91139110 return R;
91149111 };
91159112
9116- bool RequiresLazySave = CallAttrs.requiresLazySave();
9117- bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
9113+ bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9114+ bool RequiresSaveAllZA =
9115+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
91189116 if (RequiresLazySave) {
91199117 const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91209118 MachinePointerInfo MPI =
@@ -9142,18 +9140,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91429140 return DescribeCallsite(R) << " sets up a lazy save for ZA";
91439141 });
91449142 } else if (RequiresSaveAllZA) {
9145- assert(!CallAttrs.callee() .hasSharedZAInterface() &&
9143+ assert(!CalleeAttrs .hasSharedZAInterface() &&
91469144 "Cannot share state that may not exist");
91479145 Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
91489146 /*IsSave=*/true);
91499147 }
91509148
91519149 SDValue PStateSM;
9152- bool RequiresSMChange = CallAttrs .requiresSMChange();
9150+ bool RequiresSMChange = CallerAttrs .requiresSMChange(CalleeAttrs );
91539151 if (RequiresSMChange) {
9154- if (CallAttrs.caller() .hasStreamingInterfaceOrBody())
9152+ if (CallerAttrs .hasStreamingInterfaceOrBody())
91559153 PStateSM = DAG.getConstant(1, DL, MVT::i64);
9156- else if (CallAttrs.caller() .hasNonStreamingInterface())
9154+ else if (CallerAttrs .hasNonStreamingInterface())
91579155 PStateSM = DAG.getConstant(0, DL, MVT::i64);
91589156 else
91599157 PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9170,7 +9168,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91709168
91719169 SDValue ZTFrameIdx;
91729170 MachineFrameInfo &MFI = MF.getFrameInfo();
9173- bool ShouldPreserveZT0 = CallAttrs .requiresPreservingZT0();
9171+ bool ShouldPreserveZT0 = CallerAttrs .requiresPreservingZT0(CalleeAttrs );
91749172
91759173 // If the caller has ZT0 state which will not be preserved by the callee,
91769174 // spill ZT0 before the call.
@@ -9186,7 +9184,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91869184
91879185 // If caller shares ZT0 but the callee is not shared ZA, we need to stop
91889186 // PSTATE.ZA before the call if there is no lazy-save active.
9189- bool DisableZA = CallAttrs .requiresDisablingZABeforeCall();
9187+ bool DisableZA = CallerAttrs .requiresDisablingZABeforeCall(CalleeAttrs );
91909188 assert((!DisableZA || !RequiresLazySave) &&
91919189 "Lazy-save should have PSTATE.SM=1 on entry to the function");
91929190
@@ -9468,9 +9466,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94689466 InGlue = Chain.getValue(1);
94699467 }
94709468
9471- SDValue NewChain =
9472- changeStreamingMode( DAG, DL, CallAttrs.callee(). hasStreamingInterface(),
9473- Chain, InGlue, getSMCondition(CallAttrs ), PStateSM);
9469+ SDValue NewChain = changeStreamingMode(
9470+ DAG, DL, CalleeAttrs. hasStreamingInterface(), Chain, InGlue ,
9471+ getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
94749472 Chain = NewChain.getValue(0);
94759473 InGlue = NewChain.getValue(1);
94769474 }
@@ -9649,8 +9647,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96499647 if (RequiresSMChange) {
96509648 assert(PStateSM && "Expected a PStateSM to be set");
96519649 Result = changeStreamingMode(
9652- DAG, DL, !CallAttrs.callee() .hasStreamingInterface(), Result, InGlue,
9653- getSMCondition(CallAttrs ), PStateSM);
9650+ DAG, DL, !CalleeAttrs .hasStreamingInterface(), Result, InGlue,
9651+ getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
96549652
96559653 if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96569654 InGlue = Result.getValue(1);
@@ -9660,7 +9658,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96609658 }
96619659 }
96629660
9663- if (CallAttrs .requiresEnablingZAAfterCall())
9661+ if (CallerAttrs .requiresEnablingZAAfterCall(CalleeAttrs ))
96649662 // Unconditionally resume ZA.
96659663 Result = DAG.getNode(
96669664 AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28520,10 +28518,12 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2852028518
2852128519 // Checks to allow the use of SME instructions
2852228520 if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28523- auto CallAttrs = SMECallAttrs(*Base);
28524- if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28525- CallAttrs.requiresPreservingZT0() ||
28526- CallAttrs.requiresPreservingAllZAState())
28521+ auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28522+ auto CalleeAttrs = SMEAttrs(*Base);
28523+ if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28524+ CallerAttrs.requiresLazySave(CalleeAttrs) ||
28525+ CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28526+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
2852728527 return true;
2852828528 }
2852928529 return false;
0 commit comments