@@ -2631,6 +2631,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     break;
     MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
     MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
+    MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
+    MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
     MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
     MAKE_CASE(AArch64ISD::VG_SAVE)
     MAKE_CASE(AArch64ISD::VG_RESTORE)
@@ -3218,6 +3220,39 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
   return BB;
 }
 
+// TODO: Find a way to merge this with EmitAllocateZABuffer.
+MachineBasicBlock *
+AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
+                                                 MachineBasicBlock *BB) const {
+  MachineFunction *MF = BB->getParent();
+  MachineFrameInfo &MFI = MF->getFrameInfo();
+  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+  assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
+         "Lazy ZA save is not yet supported on Windows");
+
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  if (FuncInfo->getSMESaveBufferUsed()) {
+    // Allocate a buffer object of the size given by __arm_sme_state_size.
+    auto Size = MI.getOperand(1).getReg();
+    auto Dest = MI.getOperand(0).getReg();
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest)
+        .addReg(AArch64::SP)
+        .addReg(Size)
+        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            AArch64::SP)
+        .addReg(Dest);
+
+    // We have just allocated a variable sized object, tell this to PEI.
+    MFI.CreateVariableSizedObject(Align(16), nullptr);
+  } else
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
+            MI.getOperand(0).getReg());
+
+  BB->remove_instr(&MI);
+  return BB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
 
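The AllocateSMESaveBuffer pseudo expanded above only touches the stack when the save buffer is actually needed; otherwise its result is just an IMPLICIT_DEF and nothing is allocated. As a minimal standalone model of the "used" path (illustrative C++, not LLVM code), the SUBXrx64/COPY pair amounts to bumping SP down by the runtime size and handing the new SP back as the buffer address:

    #include <cstdint>

    struct MachineState {
      std::uint64_t SP; // stack pointer, grows downwards
    };

    // Hypothetical model of the expansion when getSMESaveBufferUsed() is true.
    std::uint64_t allocateSMESaveBuffer(MachineState &MS, std::uint64_t Size) {
      std::uint64_t Dest = MS.SP - Size; // SUBXrx64 Dest, SP, Size (uxtx #0)
      MS.SP = Dest;                      // COPY SP, Dest
      return Dest;                       // operand 0 of the pseudo: the buffer address
    }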
@@ -3252,6 +3287,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitInitTPIDR2Object(MI, BB);
   case AArch64::AllocateZABuffer:
     return EmitAllocateZABuffer(MI, BB);
+  case AArch64::AllocateSMESaveBuffer:
+    return EmitAllocateSMESaveBuffer(MI, BB);
+  case AArch64::GetSMESaveSize: {
+    // If the buffer is used, emit a call to __arm_sme_state_size()
+    MachineFunction *MF = BB->getParent();
+    AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+    if (FuncInfo->getSMESaveBufferUsed()) {
+      const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+          .addExternalSymbol("__arm_sme_state_size")
+          .addReg(AArch64::X0, RegState::ImplicitDefine)
+          .addRegMask(TRI->getCallPreservedMask(
+              *MF, CallingConv::
+                       AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::X0);
+    } else
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::XZR);
+    BB->remove_instr(&MI);
+    return BB;
+  }
   case AArch64::F128CSEL:
     return EmitF128CSEL(MI, BB);
   case TargetOpcode::STATEPOINT:
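GetSMESaveSize follows the same pattern: the call to __arm_sme_state_size() is only emitted when the buffer is really used, and folds to zero (a copy from XZR) otherwise, so functions whose ZA state never needs saving pay nothing. A hedged C++ sketch of that decision, assuming the usual uint64_t __arm_sme_state_size(void) prototype for the SME support routine:

    #include <cstdint>

    extern "C" std::uint64_t __arm_sme_state_size(void); // SME ABI support routine

    // Conceptual result of expanding GetSMESaveSize (illustrative only).
    std::uint64_t getSMESaveSize(bool SMESaveBufferUsed) {
      if (SMESaveBufferUsed)
        return __arm_sme_state_size(); // BL __arm_sme_state_size, result in X0
      return 0;                        // COPY from XZR: no buffer required
    }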
@@ -7651,6 +7711,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
     return CC_AArch64_AAPCS;
   case CallingConv::ARM64EC_Thunk_X64:
@@ -8110,6 +8171,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     Chain = DAG.getNode(
         AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
         {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+  } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
+    // Call __arm_sme_state_size().
+    SDValue BufferSize =
+        DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
+                    DAG.getVTList(MVT::i64, MVT::Other), Chain);
+    Chain = BufferSize.getValue(1);
+
+    SDValue Buffer;
+    if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+      Buffer =
+          DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
+                      DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
+    } else {
+      // Allocate space dynamically.
+      Buffer = DAG.getNode(
+          ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
+          {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
+      MFI.CreateVariableSizedObject(Align(16), nullptr);
+    }
+
+    // Copy the value to a virtual register, and save that in FuncInfo.
+    Register BufferPtr =
+        MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+    FuncInfo->setSMESaveBufferAddr(BufferPtr);
+    Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
   }
 
   if (CallConv == CallingConv::PreserveNone) {
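At function entry the buffer is created eagerly, and the code above picks between two allocation mechanisms. A small sketch of that choice (the enum and function names here are illustrative, not part of the patch):

    // The ALLOC_SME_SAVE_BUFFER pseudo simply bumps SP, so it is only safe when
    // no stack probing is needed; Windows and inline-stack-probe targets go
    // through the generic DYNAMIC_STACKALLOC path instead.
    enum class SaveBufferAlloc { SPBumpPseudo, DynamicStackAlloc };

    SaveBufferAlloc chooseSaveBufferAlloc(bool IsTargetWindows,
                                          bool HasInlineStackProbe) {
      if (!IsTargetWindows && !HasInlineStackProbe)
        return SaveBufferAlloc::SPBumpPseudo;    // AArch64ISD::ALLOC_SME_SAVE_BUFFER
      return SaveBufferAlloc::DynamicStackAlloc; // ISD::DYNAMIC_STACKALLOC
    }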
@@ -8398,6 +8484,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
   auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
   if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
       CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
       CallerAttrs.hasStreamingBody())
     return false;
 
@@ -8722,6 +8809,30 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
 }
 
+// Emit a call to __arm_sme_save or __arm_sme_restore.
+static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
+                                       SelectionDAG &DAG,
+                                       AArch64FunctionInfo *Info, SDLoc DL,
+                                       SDValue Chain, bool IsSave) {
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.Ty = PointerType::getUnqual(*DAG.getContext());
+  Entry.Node =
+      DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
+  Args.push_back(Entry);
+
+  SDValue Callee =
+      DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
+                            TLI.getPointerTy(DAG.getDataLayout()));
+  auto *RetTy = Type::getVoidTy(*DAG.getContext());
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
+      Callee, std::move(Args));
+
+  return TLI.LowerCallTo(CLI).second;
+}
+
 static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
                                const SMEAttrs &CalleeAttrs) {
   if (!CallerAttrs.hasStreamingCompatibleInterface() ||
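The helper above reduces to a single libcall that takes the buffer address recorded by LowerFormalArguments; IsSave only selects which symbol is called. A hedged model of its effect, with the routine prototypes written out as assumed here (void-returning, one pointer argument):

    extern "C" void __arm_sme_save(void *Blk);    // SME ABI support routine
    extern "C" void __arm_sme_restore(void *Blk); // SME ABI support routine

    // Illustrative stand-in for emitSMEStateSaveRestore's runtime effect.
    void smeStateSaveRestore(void *SMESaveBufferAddr, bool IsSave) {
      (IsSave ? __arm_sme_save : __arm_sme_restore)(SMESaveBufferAddr);
    }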
@@ -8882,6 +8993,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   };
 
   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+  bool RequiresSaveAllZA =
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
+  SDValue ZAStateBuffer;
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     MachinePointerInfo MPI =
@@ -8908,6 +9022,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                           &MF.getFunction());
       return DescribeCallsite(R) << " sets up a lazy save for ZA";
     });
+  } else if (RequiresSaveAllZA) {
+    assert(!CalleeAttrs.hasSharedZAInterface() &&
+           "Cannot share state that may not exist");
+    Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                    /*IsSave=*/true);
   }
 
   SDValue PStateSM;
@@ -9455,9 +9574,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64));
     TPIDR2.Uses++;
+  } else if (RequiresSaveAllZA) {
+    Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                     /*IsSave=*/false);
+    FuncInfo->setSMESaveBufferUsed();
   }
 
-  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
+  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
+      RequiresSaveAllZA) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
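Taken together with the save emitted before the call and the entry-block allocation, a call site where RequiresSaveAllZA holds behaves roughly like the sketch below. This is a conceptual picture, not the generated code; the real sequence is built as SelectionDAG nodes and the buffer pointer comes from the virtual register set up in LowerFormalArguments:

    extern "C" void __arm_sme_save(void *Blk);    // SME ABI support routine
    extern "C" void __arm_sme_restore(void *Blk); // SME ABI support routine

    // Conceptual shape of a protected call under RequiresSaveAllZA.
    void callWithZAPreserved(void *SMESaveBuffer, void (*Callee)()) {
      __arm_sme_save(SMESaveBuffer);    // emitted before lowering the call
      Callee();                         // the lowered call itself
      __arm_sme_restore(SMESaveBuffer); // emitted after the call; also marks
                                        // the save buffer as used
    }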
@@ -28063,7 +28187,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
     auto CalleeAttrs = SMEAttrs(*Base);
     if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
         CallerAttrs.requiresLazySave(CalleeAttrs) ||
-        CallerAttrs.requiresPreservingZT0(CalleeAttrs))
+        CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+        CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
       return true;
   }
   return false;