@@ -8244,53 +8244,54 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
82448244 if (Subtarget->hasCustomCallingConv())
82458245 Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
82468246
8247- // Create a 16 Byte TPIDR2 object. The dynamic buffer
8248- // will be expanded and stored in the static object later using a pseudonode.
8249- if (Attrs.hasZAState()) {
8250- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8251- TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
8252- SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8253- DAG.getConstant(1, DL, MVT::i32));
8254-
8255- SDValue Buffer;
8256- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8257- Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
8258- DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
8259- } else {
8260- SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
8261- Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
8262- DAG.getVTList(MVT::i64, MVT::Other),
8263- {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
8264- MFI.CreateVariableSizedObject(Align(16), nullptr);
8265- }
8266- Chain = DAG.getNode(
8267- AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
8268- {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8269- } else if (Attrs.hasAgnosticZAInterface()) {
8270- // Call __arm_sme_state_size().
8271- SDValue BufferSize =
8272- DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
8273- DAG.getVTList(MVT::i64, MVT::Other), Chain);
8274- Chain = BufferSize.getValue(1);
8275-
8276- SDValue Buffer;
8277- if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8278- Buffer =
8279- DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
8280- DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
8281- } else {
8282- // Allocate space dynamically.
8283- Buffer = DAG.getNode(
8284- ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
8285- {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
8286- MFI.CreateVariableSizedObject(Align(16), nullptr);
8247+ if (!Subtarget->useNewSMEABILowering() || Attrs.hasAgnosticZAInterface()) {
8248+ // Old SME ABI lowering (deprecated):
8249+ // Create a 16 Byte TPIDR2 object. The dynamic buffer
8250+ // will be expanded and stored in the static object later using a
8251+ // pseudonode.
8252+ if (Attrs.hasZAState()) {
8253+ TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8254+ TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
8255+ SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8256+ DAG.getConstant(1, DL, MVT::i32));
8257+ SDValue Buffer;
8258+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8259+ Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
8260+ DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
8261+ } else {
8262+ SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
8263+ Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
8264+ DAG.getVTList(MVT::i64, MVT::Other),
8265+ {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
8266+ MFI.CreateVariableSizedObject(Align(16), nullptr);
8267+ }
8268+ Chain = DAG.getNode(
8269+ AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
8270+ {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8271+ } else if (Attrs.hasAgnosticZAInterface()) {
8272+ // Call __arm_sme_state_size().
8273+ SDValue BufferSize =
8274+ DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
8275+ DAG.getVTList(MVT::i64, MVT::Other), Chain);
8276+ Chain = BufferSize.getValue(1);
8277+ SDValue Buffer;
8278+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8279+ Buffer = DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
8280+ DAG.getVTList(MVT::i64, MVT::Other),
8281+ {Chain, BufferSize});
8282+ } else {
8283+ // Allocate space dynamically.
8284+ Buffer = DAG.getNode(
8285+ ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
8286+ {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
8287+ MFI.CreateVariableSizedObject(Align(16), nullptr);
8288+ }
8289+ // Copy the value to a virtual register, and save that in FuncInfo.
8290+ Register BufferPtr =
8291+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8292+ FuncInfo->setSMESaveBufferAddr(BufferPtr);
8293+ Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
82878294 }
8288-
8289- // Copy the value to a virtual register, and save that in FuncInfo.
8290- Register BufferPtr =
8291- MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8292- FuncInfo->setSMESaveBufferAddr(BufferPtr);
8293- Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
82948295 }
82958296
82968297 if (CallConv == CallingConv::PreserveNone) {
@@ -8307,6 +8308,15 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
83078308 }
83088309 }
83098310
8311+ if (Subtarget->useNewSMEABILowering()) {
8312+ // Clear new ZT0 state. TODO: Move this to the SME ABI pass.
8313+ if (Attrs.isNewZT0())
8314+ Chain = DAG.getNode(
8315+ ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8316+ DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
8317+ DAG.getTargetConstant(0, DL, MVT::i32));
8318+ }
8319+
83108320 return Chain;
83118321}
83128322
@@ -8871,14 +8881,12 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
88718881 MachineFunction &MF = DAG.getMachineFunction();
88728882 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
88738883 FuncInfo->setSMESaveBufferUsed();
8874-
88758884 TargetLowering::ArgListTy Args;
88768885 TargetLowering::ArgListEntry Entry;
88778886 Entry.Ty = PointerType::getUnqual(*DAG.getContext());
88788887 Entry.Node =
88798888 DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
88808889 Args.push_back(Entry);
8881-
88828890 SDValue Callee =
88838891 DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
88848892 TLI.getPointerTy(DAG.getDataLayout()));
@@ -9001,6 +9009,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90019009 if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall())
90029010 CSInfo = MachineFunction::CallSiteInfo(*CB);
90039011
9012+ // Determine whether we need any streaming mode changes.
9013+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
9014+
90049015 // Check callee args/returns for SVE registers and set calling convention
90059016 // accordingly.
90069017 if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) {
@@ -9014,14 +9025,26 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90149025 CallConv = CallingConv::AArch64_SVE_VectorCall;
90159026 }
90169027
9028+ bool UseNewSMEABILowering = Subtarget->useNewSMEABILowering();
9029+ bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
9030+ auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
9031+ // TODO: Handle agnostic ZA functions.
9032+ if (!UseNewSMEABILowering || IsAgnosticZAFunction)
9033+ return std::nullopt;
9034+ if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
9035+ return std::nullopt;
9036+ return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
9037+ : AArch64ISD::INOUT_ZA_USE;
9038+ }();
9039+
90179040 if (IsTailCall) {
90189041 // Check if it's really possible to do a tail call.
90199042 IsTailCall = isEligibleForTailCallOptimization(CLI);
90209043
90219044 // A sibling call is one where we're under the usual C ABI and not planning
90229045 // to change that but can still do a tail call:
9023- if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
9024- CallConv != CallingConv::SwiftTail)
9046+ if (!ZAMarkerNode.has_value() && !TailCallOpt && IsTailCall &&
9047+ CallConv != CallingConv::Tail && CallConv != CallingConv:: SwiftTail)
90259048 IsSibCall = true;
90269049
90279050 if (IsTailCall)
@@ -9073,9 +9096,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90739096 assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
90749097 }
90759098
9076- // Determine whether we need any streaming mode changes.
9077- SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
9078-
90799099 auto DescribeCallsite =
90809100 [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
90819101 R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
@@ -9089,7 +9109,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90899109 return R;
90909110 };
90919111
9092- bool RequiresLazySave = CallAttrs.requiresLazySave();
9112+ bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
90939113 bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
90949114 if (RequiresLazySave) {
90959115 const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
@@ -9171,10 +9191,21 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91719191 AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
91729192 DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
91739193
9174- // Adjust the stack pointer for the new arguments...
9194+ // Adjust the stack pointer for the new arguments... and mark ZA uses.
91759195 // These operations are automatically eliminated by the prolog/epilog pass
9176- if (!IsSibCall)
9196+ assert((!IsSibCall || !ZAMarkerNode.has_value()) &&
9197+ "ZA markers require CALLSEQ_START");
9198+ if (!IsSibCall) {
91779199 Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
9200+ if (ZAMarkerNode) {
9201+ // Note: We need the CALLSEQ_START to glue the ZAMarkerNode to; simply
9202+ // using a chain can result in incorrect scheduling. The markers refer
9203+ // to the position just before the CALLSEQ_START (though occur after as
9204+ // CALLSEQ_START lacks in-glue).
9205+ Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
9206+ {Chain, Chain.getValue(1)});
9207+ }
9208+ }
91789209
91799210 SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
91809211 getPointerTy(DAG.getDataLayout()));
@@ -9646,7 +9677,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96469677 }
96479678 }
96489679
9649- if (CallAttrs.requiresEnablingZAAfterCall())
9680+ if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
96509681 // Unconditionally resume ZA.
96519682 Result = DAG.getNode(
96529683 AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
@@ -9667,7 +9698,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96679698 SDValue TPIDR2_EL0 = DAG.getNode(
96689699 ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
96699700 DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
9670-
96719701 // Copy the address of the TPIDR2 block into X0 before 'calling' the
96729702 // RESTORE_ZA pseudo.
96739703 SDValue Glue;
@@ -9679,7 +9709,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96799709 DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
96809710 {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
96819711 RestoreRoutine, RegMask, Result.getValue(1)});
9682-
96839712 // Finally reset the TPIDR2_EL0 register to 0.
96849713 Result = DAG.getNode(
96859714 ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
0 commit comments