@@ -8277,53 +8277,54 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
82778277 if (Subtarget->hasCustomCallingConv())
82788278 Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
82798279
8280- // Create a 16 Byte TPIDR2 object. The dynamic buffer
8281- // will be expanded and stored in the static object later using a pseudonode.
8282- if (Attrs.hasZAState()) {
8283- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8284- TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
8285- SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8286- DAG.getConstant(1, DL, MVT::i32));
8287-
8288- SDValue Buffer;
8289- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8290- Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
8291- DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
8292- } else {
8293- SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
8294- Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
8295- DAG.getVTList(MVT::i64, MVT::Other),
8296- {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
8297- MFI.CreateVariableSizedObject(Align(16), nullptr);
8298- }
8299- Chain = DAG.getNode(
8300- AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
8301- {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8302- } else if (Attrs.hasAgnosticZAInterface()) {
8303- // Call __arm_sme_state_size().
8304- SDValue BufferSize =
8305- DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
8306- DAG.getVTList(MVT::i64, MVT::Other), Chain);
8307- Chain = BufferSize.getValue(1);
8308-
8309- SDValue Buffer;
8310- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8311- Buffer =
8312- DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
8313- DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
8314- } else {
8315- // Allocate space dynamically.
8316- Buffer = DAG.getNode(
8317- ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
8318- {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
8319- MFI.CreateVariableSizedObject(Align(16), nullptr);
8280+ if (!Subtarget->useNewSMEABILowering() || Attrs.hasAgnosticZAInterface()) {
8281+ // Old SME ABI lowering (deprecated):
8282+ // Create a 16 Byte TPIDR2 object. The dynamic buffer
8283+ // will be expanded and stored in the static object later using a
8284+ // pseudonode.
8285+ if (Attrs.hasZAState()) {
8286+ TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8287+ TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
8288+ SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8289+ DAG.getConstant(1, DL, MVT::i32));
8290+ SDValue Buffer;
8291+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8292+ Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
8293+ DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
8294+ } else {
8295+ SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
8296+ Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
8297+ DAG.getVTList(MVT::i64, MVT::Other),
8298+ {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
8299+ MFI.CreateVariableSizedObject(Align(16), nullptr);
8300+ }
8301+ Chain = DAG.getNode(
8302+ AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
8303+ {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8304+ } else if (Attrs.hasAgnosticZAInterface()) {
8305+ // Call __arm_sme_state_size().
8306+ SDValue BufferSize =
8307+ DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
8308+ DAG.getVTList(MVT::i64, MVT::Other), Chain);
8309+ Chain = BufferSize.getValue(1);
8310+ SDValue Buffer;
8311+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8312+ Buffer = DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
8313+ DAG.getVTList(MVT::i64, MVT::Other),
8314+ {Chain, BufferSize});
8315+ } else {
8316+ // Allocate space dynamically.
8317+ Buffer = DAG.getNode(
8318+ ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
8319+ {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
8320+ MFI.CreateVariableSizedObject(Align(16), nullptr);
8321+ }
8322+ // Copy the value to a virtual register, and save that in FuncInfo.
8323+ Register BufferPtr =
8324+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8325+ FuncInfo->setSMESaveBufferAddr(BufferPtr);
8326+ Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
83208327 }
8321-
8322- // Copy the value to a virtual register, and save that in FuncInfo.
8323- Register BufferPtr =
8324- MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8325- FuncInfo->setSMESaveBufferAddr(BufferPtr);
8326- Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
83278328 }
83288329
83298330 if (CallConv == CallingConv::PreserveNone) {
@@ -8340,6 +8341,15 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
83408341 }
83418342 }
83428343
8344+ if (Subtarget->useNewSMEABILowering()) {
8345+ // Clear new ZT0 state. TODO: Move this to the SME ABI pass.
8346+ if (Attrs.isNewZT0())
8347+ Chain = DAG.getNode(
8348+ ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8349+ DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
8350+ DAG.getTargetConstant(0, DL, MVT::i32));
8351+ }
8352+
83438353 return Chain;
83448354}
83458355
@@ -8911,7 +8921,6 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89118921 MachineFunction &MF = DAG.getMachineFunction();
89128922 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
89138923 FuncInfo->setSMESaveBufferUsed();
8914-
89158924 TargetLowering::ArgListTy Args;
89168925 TargetLowering::ArgListEntry Entry;
89178926 Entry.Ty = PointerType::getUnqual(*DAG.getContext());
@@ -9041,6 +9050,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90419050 if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall())
90429051 CSInfo = MachineFunction::CallSiteInfo(*CB);
90439052
9053+ // Determine whether we need any streaming mode changes.
9054+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
9055+
90449056 // Check callee args/returns for SVE registers and set calling convention
90459057 // accordingly.
90469058 if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) {
@@ -9054,14 +9066,26 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90549066 CallConv = CallingConv::AArch64_SVE_VectorCall;
90559067 }
90569068
9069+ bool UseNewSMEABILowering = Subtarget->useNewSMEABILowering();
9070+ bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
9071+ auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
9072+ // TODO: Handle agnostic ZA functions.
9073+ if (!UseNewSMEABILowering || IsAgnosticZAFunction)
9074+ return std::nullopt;
9075+ if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
9076+ return std::nullopt;
9077+ return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
9078+ : AArch64ISD::INOUT_ZA_USE;
9079+ }();
9080+
90579081 if (IsTailCall) {
90589082 // Check if it's really possible to do a tail call.
90599083 IsTailCall = isEligibleForTailCallOptimization(CLI);
90609084
90619085 // A sibling call is one where we're under the usual C ABI and not planning
90629086 // to change that but can still do a tail call:
9063- if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
9064- CallConv != CallingConv::SwiftTail)
9087+ if (!ZAMarkerNode.has_value() && !TailCallOpt && IsTailCall &&
9088+ CallConv != CallingConv::Tail && CallConv != CallingConv:: SwiftTail)
90659089 IsSibCall = true;
90669090
90679091 if (IsTailCall)
@@ -9113,9 +9137,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91139137 assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
91149138 }
91159139
9116- // Determine whether we need any streaming mode changes.
9117- SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
9118-
91199140 auto DescribeCallsite =
91209141 [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
91219142 R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
@@ -9129,7 +9150,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91299150 return R;
91309151 };
91319152
9132- bool RequiresLazySave = CallAttrs.requiresLazySave();
9153+ bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
91339154 bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
91349155 if (RequiresLazySave) {
91359156 const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
@@ -9204,10 +9225,21 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
92049225 AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
92059226 DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
92069227
9207- // Adjust the stack pointer for the new arguments...
9228+ // Adjust the stack pointer for the new arguments... and mark ZA uses.
92089229 // These operations are automatically eliminated by the prolog/epilog pass
9209- if (!IsSibCall)
9230+ assert((!IsSibCall || !ZAMarkerNode.has_value()) &&
9231+ "ZA markers require CALLSEQ_START");
9232+ if (!IsSibCall) {
92109233 Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
9234+ if (ZAMarkerNode) {
9235+ // Note: We need the CALLSEQ_START to glue the ZAMarkerNode to, simply
9236+ // using a chain can result in incorrect scheduling. The markers referer
9237+ // to the position just before the CALLSEQ_START (though occur after as
9238+ // CALLSEQ_START lacks in-glue).
9239+ Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
9240+ {Chain, Chain.getValue(1)});
9241+ }
9242+ }
92119243
92129244 SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
92139245 getPointerTy(DAG.getDataLayout()));
@@ -9678,7 +9710,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96789710 }
96799711 }
96809712
9681- if (CallAttrs.requiresEnablingZAAfterCall())
9713+ if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
96829714 // Unconditionally resume ZA.
96839715 Result = DAG.getNode(
96849716 AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
@@ -9700,7 +9732,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
97009732 SDValue TPIDR2_EL0 = DAG.getNode(
97019733 ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
97029734 DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
9703-
97049735 // Copy the address of the TPIDR2 block into X0 before 'calling' the
97059736 // RESTORE_ZA pseudo.
97069737 SDValue Glue;
@@ -9712,7 +9743,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
97129743 DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
97139744 {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
97149745 RestoreRoutine, RegMask, Result.getValue(1)});
9715-
97169746 // Finally reset the TPIDR2_EL0 register to 0.
97179747 Result = DAG.getNode(
97189748 ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
0 commit comments