@@ -8284,53 +8284,54 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
82848284 if (Subtarget->hasCustomCallingConv())
82858285 Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
82868286
8287- // Create a 16 Byte TPIDR2 object. The dynamic buffer
8288- // will be expanded and stored in the static object later using a pseudonode.
8289- if (Attrs.hasZAState()) {
8290- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8291- TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
8292- SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8293- DAG.getConstant(1, DL, MVT::i32));
8294-
8295- SDValue Buffer;
8296- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8297- Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
8298- DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
8299- } else {
8300- SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
8301- Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
8302- DAG.getVTList(MVT::i64, MVT::Other),
8303- {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
8304- MFI.CreateVariableSizedObject(Align(16), nullptr);
8305- }
8306- Chain = DAG.getNode(
8307- AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
8308- {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8309- } else if (Attrs.hasAgnosticZAInterface()) {
8310- // Call __arm_sme_state_size().
8311- SDValue BufferSize =
8312- DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
8313- DAG.getVTList(MVT::i64, MVT::Other), Chain);
8314- Chain = BufferSize.getValue(1);
8315-
8316- SDValue Buffer;
8317- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8318- Buffer =
8319- DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
8320- DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
8321- } else {
8322- // Allocate space dynamically.
8323- Buffer = DAG.getNode(
8324- ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
8325- {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
8326- MFI.CreateVariableSizedObject(Align(16), nullptr);
8287+ if (!Subtarget->useNewSMEABILowering() || Attrs.hasAgnosticZAInterface()) {
8288+ // Old SME ABI lowering (deprecated):
8289+ // Create a 16 Byte TPIDR2 object. The dynamic buffer
8290+ // will be expanded and stored in the static object later using a
8291+ // pseudonode.
8292+ if (Attrs.hasZAState()) {
8293+ TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8294+ TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
8295+ SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8296+ DAG.getConstant(1, DL, MVT::i32));
8297+ SDValue Buffer;
8298+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8299+ Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
8300+ DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
8301+ } else {
8302+ SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
8303+ Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
8304+ DAG.getVTList(MVT::i64, MVT::Other),
8305+ {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
8306+ MFI.CreateVariableSizedObject(Align(16), nullptr);
8307+ }
8308+ Chain = DAG.getNode(
8309+ AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
8310+ {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8311+ } else if (Attrs.hasAgnosticZAInterface()) {
8312+ // Call __arm_sme_state_size().
8313+ SDValue BufferSize =
8314+ DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
8315+ DAG.getVTList(MVT::i64, MVT::Other), Chain);
8316+ Chain = BufferSize.getValue(1);
8317+ SDValue Buffer;
8318+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8319+ Buffer = DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
8320+ DAG.getVTList(MVT::i64, MVT::Other),
8321+ {Chain, BufferSize});
8322+ } else {
8323+ // Allocate space dynamically.
8324+ Buffer = DAG.getNode(
8325+ ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
8326+ {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
8327+ MFI.CreateVariableSizedObject(Align(16), nullptr);
8328+ }
8329+ // Copy the value to a virtual register, and save that in FuncInfo.
8330+ Register BufferPtr =
8331+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8332+ FuncInfo->setSMESaveBufferAddr(BufferPtr);
8333+ Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
83278334 }
8328-
8329- // Copy the value to a virtual register, and save that in FuncInfo.
8330- Register BufferPtr =
8331- MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8332- FuncInfo->setSMESaveBufferAddr(BufferPtr);
8333- Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
83348335 }
83358336
83368337 if (CallConv == CallingConv::PreserveNone) {
@@ -8347,6 +8348,15 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
83478348 }
83488349 }
83498350
8351+ if (Subtarget->useNewSMEABILowering()) {
8352+ // Clear new ZT0 state. TODO: Move this to the SME ABI pass.
8353+ if (Attrs.isNewZT0())
8354+ Chain = DAG.getNode(
8355+ ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8356+ DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
8357+ DAG.getTargetConstant(0, DL, MVT::i32));
8358+ }
8359+
83508360 return Chain;
83518361}
83528362
@@ -8918,7 +8928,6 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89188928 MachineFunction &MF = DAG.getMachineFunction();
89198929 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
89208930 FuncInfo->setSMESaveBufferUsed();
8921-
89228931 TargetLowering::ArgListTy Args;
89238932 Args.emplace_back(
89248933 DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
@@ -9046,6 +9055,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90469055 if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall())
90479056 CSInfo = MachineFunction::CallSiteInfo(*CB);
90489057
9058+ // Determine whether we need any streaming mode changes.
9059+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
9060+
90499061 // Check callee args/returns for SVE registers and set calling convention
90509062 // accordingly.
90519063 if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) {
@@ -9059,14 +9071,26 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90599071 CallConv = CallingConv::AArch64_SVE_VectorCall;
90609072 }
90619073
9074+ bool UseNewSMEABILowering = Subtarget->useNewSMEABILowering();
9075+ bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
9076+ auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
9077+ // TODO: Handle agnostic ZA functions.
9078+ if (!UseNewSMEABILowering || IsAgnosticZAFunction)
9079+ return std::nullopt;
9080+ if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
9081+ return std::nullopt;
9082+ return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
9083+ : AArch64ISD::INOUT_ZA_USE;
9084+ }();
9085+
90629086 if (IsTailCall) {
90639087 // Check if it's really possible to do a tail call.
90649088 IsTailCall = isEligibleForTailCallOptimization(CLI);
90659089
90669090 // A sibling call is one where we're under the usual C ABI and not planning
90679091 // to change that but can still do a tail call:
9068- if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
9069- CallConv != CallingConv::SwiftTail)
9092+ if (!ZAMarkerNode.has_value() && !TailCallOpt && IsTailCall &&
9093+ CallConv != CallingConv::Tail && CallConv != CallingConv:: SwiftTail)
90709094 IsSibCall = true;
90719095
90729096 if (IsTailCall)
@@ -9118,9 +9142,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91189142 assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
91199143 }
91209144
9121- // Determine whether we need any streaming mode changes.
9122- SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
9123-
91249145 auto DescribeCallsite =
91259146 [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
91269147 R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
@@ -9134,7 +9155,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91349155 return R;
91359156 };
91369157
9137- bool RequiresLazySave = CallAttrs.requiresLazySave();
9158+ bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
91389159 bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
91399160 if (RequiresLazySave) {
91409161 const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
@@ -9209,10 +9230,21 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
92099230 AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
92109231 DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
92119232
9212- // Adjust the stack pointer for the new arguments...
9233+ // Adjust the stack pointer for the new arguments... and mark ZA uses.
92139234 // These operations are automatically eliminated by the prolog/epilog pass
9214- if (!IsSibCall)
9235+ assert((!IsSibCall || !ZAMarkerNode.has_value()) &&
9236+ "ZA markers require CALLSEQ_START");
9237+ if (!IsSibCall) {
92159238 Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
9239+ if (ZAMarkerNode) {
9240+ // Note: We need the CALLSEQ_START to glue the ZAMarkerNode to, simply
9242+ // using a chain can result in incorrect scheduling. The markers refer
9242+ // to the position just before the CALLSEQ_START (though occur after as
9243+ // CALLSEQ_START lacks in-glue).
9244+ Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
9245+ {Chain, Chain.getValue(1)});
9246+ }
9247+ }
92169248
92179249 SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
92189250 getPointerTy(DAG.getDataLayout()));
@@ -9683,7 +9715,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96839715 }
96849716 }
96859717
9686- if (CallAttrs.requiresEnablingZAAfterCall())
9718+ if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
96879719 // Unconditionally resume ZA.
96889720 Result = DAG.getNode(
96899721 AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
@@ -9705,7 +9737,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
97059737 SDValue TPIDR2_EL0 = DAG.getNode(
97069738 ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
97079739 DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
9708-
97099740 // Copy the address of the TPIDR2 block into X0 before 'calling' the
97109741 // RESTORE_ZA pseudo.
97119742 SDValue Glue;
@@ -9717,7 +9748,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
97179748 DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
97189749 {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
97199750 RestoreRoutine, RegMask, Result.getValue(1)});
9720-
97219751 // Finally reset the TPIDR2_EL0 register to 0.
97229752 Result = DAG.getNode(
97239753 ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
0 commit comments