@@ -2643,6 +2643,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     break;
     MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
     MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
+    MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
+    MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
     MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
     MAKE_CASE(AArch64ISD::VG_SAVE)
     MAKE_CASE(AArch64ISD::VG_RESTORE)
@@ -3230,6 +3232,39 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
   return BB;
 }
 
+// TODO: Find a way to merge this with EmitAllocateZABuffer.
+MachineBasicBlock *
+AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
+                                                 MachineBasicBlock *BB) const {
+  MachineFunction *MF = BB->getParent();
+  MachineFrameInfo &MFI = MF->getFrameInfo();
+  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+  assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
+         "Lazy ZA save is not yet supported on Windows");
+
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  if (FuncInfo->getSMESaveBufferUsed()) {
+    // Allocate a buffer object of the given size (from __arm_sme_state_size).
+    auto Size = MI.getOperand(1).getReg();
+    auto Dest = MI.getOperand(0).getReg();
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest)
+        .addReg(AArch64::SP)
+        .addReg(Size)
+        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            AArch64::SP)
+        .addReg(Dest);
+
+    // We have just allocated a variable sized object, tell this to PEI.
+    MFI.CreateVariableSizedObject(Align(16), nullptr);
+  } else
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
+            MI.getOperand(0).getReg());
+
+  BB->remove_instr(&MI);
+  return BB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
 
@@ -3264,6 +3299,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitInitTPIDR2Object(MI, BB);
   case AArch64::AllocateZABuffer:
     return EmitAllocateZABuffer(MI, BB);
+  case AArch64::AllocateSMESaveBuffer:
+    return EmitAllocateSMESaveBuffer(MI, BB);
+  case AArch64::GetSMESaveSize: {
+    // If the buffer is used, emit a call to __arm_sme_state_size().
+    MachineFunction *MF = BB->getParent();
+    AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+    if (FuncInfo->getSMESaveBufferUsed()) {
+      const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+          .addExternalSymbol("__arm_sme_state_size")
+          .addReg(AArch64::X0, RegState::ImplicitDefine)
+          .addRegMask(TRI->getCallPreservedMask(
+              *MF, CallingConv::
+                       AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::X0);
+    } else
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::XZR);
+    BB->remove_instr(&MI);
+    return BB;
+  }
   case AArch64::F128CSEL:
     return EmitF128CSEL(MI, BB);
   case TargetOpcode::STATEPOINT:
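For context, __arm_sme_state_size is one of the SME ABI support routines and is not defined in this patch. A rough C-level sketch of its shape, inferred from how the custom inserter above uses it (the i64 result comes back in X0, and most registers from X1 onwards are preserved), would be:

    /* Assumed declaration, based on the usage above and the Arm SME ABI; the
       authoritative definition lives in the ABI spec / runtime, not here. */
    unsigned long long __arm_sme_state_size(void); /* bytes needed to save ZA state */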
@@ -7663,6 +7723,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
     return CC_AArch64_AAPCS;
   case CallingConv::ARM64EC_Thunk_X64:
@@ -8122,6 +8183,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     Chain = DAG.getNode(
         AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
         {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+  } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
+    // Call __arm_sme_state_size().
+    SDValue BufferSize =
+        DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
+                    DAG.getVTList(MVT::i64, MVT::Other), Chain);
+    Chain = BufferSize.getValue(1);
+
+    SDValue Buffer;
+    if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+      Buffer =
+          DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
+                      DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
+    } else {
+      // Allocate space dynamically.
+      Buffer = DAG.getNode(
+          ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
+          {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
+      MFI.CreateVariableSizedObject(Align(16), nullptr);
+    }
+
+    // Copy the value to a virtual register, and save that in FuncInfo.
+    Register BufferPtr =
+        MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+    FuncInfo->setSMESaveBufferAddr(BufferPtr);
+    Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
   }
 
   if (CallConv == CallingConv::PreserveNone) {
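The hasAgnosticZAInterface() path above is reached for functions whose ZA interface is "agnostic". As an illustration only (the ACLE attribute spelling is assumed here, not taken from this patch), such a function at the C level looks roughly like:

    /* Hypothetical example: an agnostic-ZA function. On entry, the lowering
       above reserves __arm_sme_state_size() bytes for a ZA save buffer, and
       calls to private-ZA callees get wrapped in __arm_sme_save/__arm_sme_restore. */
    void private_za_callee(void);
    void agnostic_fn(void) __arm_agnostic("sme_za_state") {
      private_za_callee();
    }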
@@ -8410,6 +8496,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
   auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
   if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
       CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
       CallerAttrs.hasStreamingBody())
     return false;
 
@@ -8734,6 +8821,30 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
 }
 
+// Emit a call to __arm_sme_save or __arm_sme_restore.
+static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
+                                       SelectionDAG &DAG,
+                                       AArch64FunctionInfo *Info, SDLoc DL,
+                                       SDValue Chain, bool IsSave) {
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.Ty = PointerType::getUnqual(*DAG.getContext());
+  Entry.Node =
+      DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
+  Args.push_back(Entry);
+
+  SDValue Callee =
+      DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
+                            TLI.getPointerTy(DAG.getDataLayout()));
+  auto *RetTy = Type::getVoidTy(*DAG.getContext());
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
+      Callee, std::move(Args));
+
+  return TLI.LowerCallTo(CLI).second;
+}
+
 static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
                                const SMEAttrs &CalleeAttrs) {
   if (!CallerAttrs.hasStreamingCompatibleInterface() ||
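The helper above expands to an ordinary call; the two routines it targets are again SME ABI support routines. Their assumed C-level signatures, matching how they are invoked here (one pointer argument, void result, preserve-most-from-X1 calling convention), are:

    /* Assumed declarations; see the Arm SME ABI for the authoritative ones. */
    void __arm_sme_save(void *state_buffer);    /* save ZA state into the buffer */
    void __arm_sme_restore(void *state_buffer); /* restore ZA state from it */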
@@ -8894,6 +9005,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   };
 
   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+  bool RequiresSaveAllZA =
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
+  SDValue ZAStateBuffer;
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     MachinePointerInfo MPI =
@@ -8920,6 +9034,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                                    &MF.getFunction());
       return DescribeCallsite(R) << " sets up a lazy save for ZA";
     });
+  } else if (RequiresSaveAllZA) {
+    assert(!CalleeAttrs.hasSharedZAInterface() &&
+           "Cannot share state that may not exist");
+    Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                    /*IsSave=*/true);
   }
 
   SDValue PStateSM;
@@ -9467,9 +9586,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64));
     TPIDR2.Uses++;
+  } else if (RequiresSaveAllZA) {
+    Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                     /*IsSave=*/false);
+    FuncInfo->setSMESaveBufferUsed();
   }
 
-  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
+  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
+      RequiresSaveAllZA) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
@@ -28084,7 +28208,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
     auto CalleeAttrs = SMEAttrs(*Base);
     if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
         CallerAttrs.requiresLazySave(CalleeAttrs) ||
-        CallerAttrs.requiresPreservingZT0(CalleeAttrs))
+        CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+        CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
       return true;
   }
   return false;