@@ -8154,53 +8154,54 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
81548154 if (Subtarget->hasCustomCallingConv())
81558155 Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
81568156
8157- // Create a 16 Byte TPIDR2 object. The dynamic buffer
8158- // will be expanded and stored in the static object later using a pseudonode.
8159- if (Attrs.hasZAState()) {
8160- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8161- TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
8162- SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8163- DAG.getConstant(1, DL, MVT::i32));
8164-
8165- SDValue Buffer;
8166- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8167- Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
8168- DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
8169- } else {
8170- SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
8171- Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
8172- DAG.getVTList(MVT::i64, MVT::Other),
8173- {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
8174- MFI.CreateVariableSizedObject(Align(16), nullptr);
8175- }
8176- Chain = DAG.getNode(
8177- AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
8178- {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8179- } else if (Attrs.hasAgnosticZAInterface()) {
8180- // Call __arm_sme_state_size().
8181- SDValue BufferSize =
8182- DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
8183- DAG.getVTList(MVT::i64, MVT::Other), Chain);
8184- Chain = BufferSize.getValue(1);
8185-
8186- SDValue Buffer;
8187- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8188- Buffer =
8189- DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
8190- DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
8191- } else {
8192- // Allocate space dynamically.
8193- Buffer = DAG.getNode(
8194- ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
8195- {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
8196- MFI.CreateVariableSizedObject(Align(16), nullptr);
8157+ if (!Subtarget->useNewSMEABILowering() || Attrs.hasAgnosticZAInterface()) {
8158+ // Old SME ABI lowering (deprecated):
8159+ // Create a 16 Byte TPIDR2 object. The dynamic buffer
8160+ // will be expanded and stored in the static object later using a
8161+ // pseudonode.
8162+ if (Attrs.hasZAState()) {
8163+ TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8164+ TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
8165+ SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8166+ DAG.getConstant(1, DL, MVT::i32));
8167+ SDValue Buffer;
8168+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8169+ Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
8170+ DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
8171+ } else {
8172+ SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
8173+ Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
8174+ DAG.getVTList(MVT::i64, MVT::Other),
8175+ {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
8176+ MFI.CreateVariableSizedObject(Align(16), nullptr);
8177+ }
8178+ Chain = DAG.getNode(
8179+ AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
8180+ {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8181+ } else if (Attrs.hasAgnosticZAInterface()) {
8182+ // Call __arm_sme_state_size().
8183+ SDValue BufferSize =
8184+ DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
8185+ DAG.getVTList(MVT::i64, MVT::Other), Chain);
8186+ Chain = BufferSize.getValue(1);
8187+ SDValue Buffer;
8188+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8189+ Buffer = DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
8190+ DAG.getVTList(MVT::i64, MVT::Other),
8191+ {Chain, BufferSize});
8192+ } else {
8193+ // Allocate space dynamically.
8194+ Buffer = DAG.getNode(
8195+ ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
8196+ {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
8197+ MFI.CreateVariableSizedObject(Align(16), nullptr);
8198+ }
8199+ // Copy the value to a virtual register, and save that in FuncInfo.
8200+ Register BufferPtr =
8201+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8202+ FuncInfo->setSMESaveBufferAddr(BufferPtr);
8203+ Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
81978204 }
8198-
8199- // Copy the value to a virtual register, and save that in FuncInfo.
8200- Register BufferPtr =
8201- MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8202- FuncInfo->setSMESaveBufferAddr(BufferPtr);
8203- Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
82048205 }
82058206
82068207 if (CallConv == CallingConv::PreserveNone) {
@@ -8217,6 +8218,15 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
82178218 }
82188219 }
82198220
8221+ if (Subtarget->useNewSMEABILowering()) {
8222+ // Clear new ZT0 state. TODO: Move this to the SME ABI pass.
8223+ if (Attrs.isNewZT0())
8224+ Chain = DAG.getNode(
8225+ ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8226+ DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
8227+ DAG.getTargetConstant(0, DL, MVT::i32));
8228+ }
8229+
82208230 return Chain;
82218231}
82228232
@@ -8781,14 +8791,12 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
87818791 MachineFunction &MF = DAG.getMachineFunction();
87828792 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
87838793 FuncInfo->setSMESaveBufferUsed();
8784-
87858794 TargetLowering::ArgListTy Args;
87868795 TargetLowering::ArgListEntry Entry;
87878796 Entry.Ty = PointerType::getUnqual(*DAG.getContext());
87888797 Entry.Node =
87898798 DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
87908799 Args.push_back(Entry);
8791-
87928800 SDValue Callee =
87938801 DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
87948802 TLI.getPointerTy(DAG.getDataLayout()));
@@ -8906,6 +8914,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
89068914 *DAG.getContext());
89078915 RetCCInfo.AnalyzeCallResult(Ins, RetCC);
89088916
8917+ // Determine whether we need any streaming mode changes.
8918+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
8919+
89098920 // Check callee args/returns for SVE registers and set calling convention
89108921 // accordingly.
89118922 if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) {
@@ -8919,14 +8930,26 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
89198930 CallConv = CallingConv::AArch64_SVE_VectorCall;
89208931 }
89218932
8933+ bool UseNewSMEABILowering = Subtarget->useNewSMEABILowering();
8934+ bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
8935+ auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
8936+ // TODO: Handle agnostic ZA functions.
8937+ if (!UseNewSMEABILowering || IsAgnosticZAFunction)
8938+ return std::nullopt;
8939+ if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
8940+ return std::nullopt;
8941+ return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
8942+ : AArch64ISD::INOUT_ZA_USE;
8943+ }();
8944+
89228945 if (IsTailCall) {
89238946 // Check if it's really possible to do a tail call.
89248947 IsTailCall = isEligibleForTailCallOptimization(CLI);
89258948
89268949 // A sibling call is one where we're under the usual C ABI and not planning
89278950 // to change that but can still do a tail call:
8928- if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
8929- CallConv != CallingConv::SwiftTail)
8951+ if (!ZAMarkerNode.has_value() && !TailCallOpt && IsTailCall &&
8952+ CallConv != CallingConv::Tail && CallConv != CallingConv:: SwiftTail)
89308953 IsSibCall = true;
89318954
89328955 if (IsTailCall)
@@ -8978,9 +9001,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
89789001 assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
89799002 }
89809003
8981- // Determine whether we need any streaming mode changes.
8982- SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
8983-
89849004 auto DescribeCallsite =
89859005 [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
89869006 R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
@@ -8994,7 +9014,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
89949014 return R;
89959015 };
89969016
8997- bool RequiresLazySave = CallAttrs.requiresLazySave();
9017+ bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
89989018 bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
89999019 if (RequiresLazySave) {
90009020 const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
@@ -9076,10 +9096,21 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90769096 AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
90779097 DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
90789098
9079- // Adjust the stack pointer for the new arguments...
9099+ // Adjust the stack pointer for the new arguments... and mark ZA uses.
90809100 // These operations are automatically eliminated by the prolog/epilog pass
9081- if (!IsSibCall)
9101+ assert((!IsSibCall || !ZAMarkerNode.has_value()) &&
9102+ "ZA markers require CALLSEQ_START");
9103+ if (!IsSibCall) {
90829104 Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
9105+ if (ZAMarkerNode) {
9106+ // Note: We need the CALLSEQ_START to glue the ZAMarkerNode to, simply
9107+ // using a chain can result in incorrect scheduling. The markers referer
9108+ // to the position just before the CALLSEQ_START (though occur after as
9109+ // CALLSEQ_START lacks in-glue).
9110+ Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
9111+ {Chain, Chain.getValue(1)});
9112+ }
9113+ }
90839114
90849115 SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
90859116 getPointerTy(DAG.getDataLayout()));
@@ -9551,7 +9582,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
95519582 }
95529583 }
95539584
9554- if (CallAttrs.requiresEnablingZAAfterCall())
9585+ if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
95559586 // Unconditionally resume ZA.
95569587 Result = DAG.getNode(
95579588 AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
@@ -9572,7 +9603,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
95729603 SDValue TPIDR2_EL0 = DAG.getNode(
95739604 ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
95749605 DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
9575-
95769606 // Copy the address of the TPIDR2 block into X0 before 'calling' the
95779607 // RESTORE_ZA pseudo.
95789608 SDValue Glue;
@@ -9584,7 +9614,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
95849614 DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
95859615 {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
95869616 RestoreRoutine, RegMask, Result.getValue(1)});
9587-
95889617 // Finally reset the TPIDR2_EL0 register to 0.
95899618 Result = DAG.getNode(
95909619 ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
0 commit comments