@@ -2643,6 +2643,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
26432643 break;
26442644 MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
26452645 MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
2646+ MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
2647+ MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
26462648 MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
26472649 MAKE_CASE(AArch64ISD::VG_SAVE)
26482650 MAKE_CASE(AArch64ISD::VG_RESTORE)
@@ -3230,6 +3232,64 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
32303232 return BB;
32313233}
32323234
3235+ // TODO: Find a way to merge this with EmitAllocateZABuffer.
3236+ MachineBasicBlock *
3237+ AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
3238+ MachineBasicBlock *BB) const {
3239+ MachineFunction *MF = BB->getParent();
3240+ MachineFrameInfo &MFI = MF->getFrameInfo();
3241+ AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
3242+ assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
3243+ "Lazy ZA save is not yet supported on Windows");
3244+
3245+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3246+ if (FuncInfo->isSMESaveBufferUsed()) {
3247+ // Allocate a buffer object of the size given by MI.getOperand(1).
3248+ auto Size = MI.getOperand(1).getReg();
3249+ auto Dest = MI.getOperand(0).getReg();
3250+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), AArch64::SP)
3251+ .addReg(AArch64::SP)
3252+ .addReg(Size)
3253+ .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
3254+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), Dest)
3255+ .addReg(AArch64::SP);
3256+
3257+ // We have just allocated a variable sized object, tell this to PEI.
3258+ MFI.CreateVariableSizedObject(Align(16), nullptr);
3259+ } else
3260+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
3261+ MI.getOperand(0).getReg());
3262+
3263+ BB->remove_instr(&MI);
3264+ return BB;
3265+ }
3266+
3267+ MachineBasicBlock *
3268+ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
3269+ MachineBasicBlock *BB) const {
3270+ // If the buffer is used, emit a call to __arm_sme_state_size()
3271+ MachineFunction *MF = BB->getParent();
3272+ AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
3273+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3274+ if (FuncInfo->isSMESaveBufferUsed()) {
3275+ const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
3276+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
3277+ .addExternalSymbol("__arm_sme_state_size")
3278+ .addReg(AArch64::X0, RegState::ImplicitDefine)
3279+ .addRegMask(TRI->getCallPreservedMask(
3280+ *MF, CallingConv::
3281+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
3282+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
3283+ MI.getOperand(0).getReg())
3284+ .addReg(AArch64::X0);
3285+ } else
3286+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
3287+ MI.getOperand(0).getReg())
3288+ .addReg(AArch64::XZR);
3289+ BB->remove_instr(&MI);
3290+ return BB;
3291+ }
3292+
32333293MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
32343294 MachineInstr &MI, MachineBasicBlock *BB) const {
32353295
@@ -3264,6 +3324,10 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
32643324 return EmitInitTPIDR2Object(MI, BB);
32653325 case AArch64::AllocateZABuffer:
32663326 return EmitAllocateZABuffer(MI, BB);
3327+ case AArch64::AllocateSMESaveBuffer:
3328+ return EmitAllocateSMESaveBuffer(MI, BB);
3329+ case AArch64::GetSMESaveSize:
3330+ return EmitGetSMESaveSize(MI, BB);
32673331 case AArch64::F128CSEL:
32683332 return EmitF128CSEL(MI, BB);
32693333 case TargetOpcode::STATEPOINT:
@@ -7663,6 +7727,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
76637727 case CallingConv::AArch64_VectorCall:
76647728 case CallingConv::AArch64_SVE_VectorCall:
76657729 case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
7730+ case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
76667731 case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
76677732 return CC_AArch64_AAPCS;
76687733 case CallingConv::ARM64EC_Thunk_X64:
@@ -8122,6 +8187,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
81228187 Chain = DAG.getNode(
81238188 AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
81248189 {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8190+ } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
8191+ // Call __arm_sme_state_size().
8192+ SDValue BufferSize =
8193+ DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
8194+ DAG.getVTList(MVT::i64, MVT::Other), Chain);
8195+ Chain = BufferSize.getValue(1);
8196+
8197+ SDValue Buffer;
8198+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
8199+ Buffer =
8200+ DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
8201+ DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
8202+ } else {
8203+ // Allocate space dynamically.
8204+ Buffer = DAG.getNode(
8205+ ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
8206+ {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
8207+ MFI.CreateVariableSizedObject(Align(16), nullptr);
8208+ }
8209+
8210+ // Copy the value to a virtual register, and save that in FuncInfo.
8211+ Register BufferPtr =
8212+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8213+ FuncInfo->setSMESaveBufferAddr(BufferPtr);
8214+ Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
81258215 }
81268216
81278217 if (CallConv == CallingConv::PreserveNone) {
@@ -8410,6 +8500,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
84108500 auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
84118501 if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
84128502 CallerAttrs.requiresLazySave(CalleeAttrs) ||
8503+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
84138504 CallerAttrs.hasStreamingBody())
84148505 return false;
84158506
@@ -8734,6 +8825,33 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
87348825 return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
87358826}
87368827
8828+ // Emit a call to __arm_sme_save or __arm_sme_restore.
8829+ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
8830+ SelectionDAG &DAG,
8831+ AArch64FunctionInfo *Info, SDLoc DL,
8832+ SDValue Chain, bool IsSave) {
8833+ MachineFunction &MF = DAG.getMachineFunction();
8834+ AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8835+ FuncInfo->setSMESaveBufferUsed();
8836+
8837+ TargetLowering::ArgListTy Args;
8838+ TargetLowering::ArgListEntry Entry;
8839+ Entry.Ty = PointerType::getUnqual(*DAG.getContext());
8840+ Entry.Node =
8841+ DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
8842+ Args.push_back(Entry);
8843+
8844+ SDValue Callee =
8845+ DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
8846+ TLI.getPointerTy(DAG.getDataLayout()));
8847+ auto *RetTy = Type::getVoidTy(*DAG.getContext());
8848+ TargetLowering::CallLoweringInfo CLI(DAG);
8849+ CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
8850+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
8851+ Callee, std::move(Args));
8852+ return TLI.LowerCallTo(CLI).second;
8853+ }
8854+
87378855static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
87388856 const SMEAttrs &CalleeAttrs) {
87398857 if (!CallerAttrs.hasStreamingCompatibleInterface() ||
@@ -8894,6 +9012,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
88949012 };
88959013
88969014 bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9015+ bool RequiresSaveAllZA =
9016+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
88979017 if (RequiresLazySave) {
88989018 const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
88999019 MachinePointerInfo MPI =
@@ -8920,6 +9040,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
89209040 &MF.getFunction());
89219041 return DescribeCallsite(R) << " sets up a lazy save for ZA";
89229042 });
9043+ } else if (RequiresSaveAllZA) {
9044+ assert(!CalleeAttrs.hasSharedZAInterface() &&
9045+ "Cannot share state that may not exist");
9046+ Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
9047+ /*IsSave=*/true);
89239048 }
89249049
89259050 SDValue PStateSM;
@@ -9467,9 +9592,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94679592 DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
94689593 DAG.getConstant(0, DL, MVT::i64));
94699594 TPIDR2.Uses++;
9595+ } else if (RequiresSaveAllZA) {
9596+ Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
9597+ /*IsSave=*/false);
94709598 }
94719599
9472- if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
9600+ if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
9601+ RequiresSaveAllZA) {
94739602 for (unsigned I = 0; I < InVals.size(); ++I) {
94749603 // The smstart/smstop is chained as part of the call, but when the
94759604 // resulting chain is discarded (which happens when the call is not part
@@ -28084,7 +28213,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2808428213 auto CalleeAttrs = SMEAttrs(*Base);
2808528214 if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
2808628215 CallerAttrs.requiresLazySave(CalleeAttrs) ||
28087- CallerAttrs.requiresPreservingZT0(CalleeAttrs))
28216+ CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28217+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
2808828218 return true;
2808928219 }
2809028220 return false;
0 commit comments