@@ -3244,16 +3244,15 @@ AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
32443244
32453245 const TargetInstrInfo *TII = Subtarget->getInstrInfo();
32463246 if (FuncInfo->getSMESaveBufferUsed()) {
3247- // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
3247+ // Allocate a buffer object of the size given by MI.getOperand(1).
32483248 auto Size = MI.getOperand(1).getReg();
32493249 auto Dest = MI.getOperand(0).getReg();
3250- BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest )
3250+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), AArch64::SP )
32513251 .addReg(AArch64::SP)
32523252 .addReg(Size)
32533253 .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
3254- BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
3255- AArch64::SP)
3256- .addReg(Dest);
3254+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), Dest)
3255+ .addReg(AArch64::SP);
32573256
32583257 // We have just allocated a variable sized object, tell this to PEI.
32593258 MFI.CreateVariableSizedObject(Align(16), nullptr);
@@ -3265,6 +3264,32 @@ AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
32653264 return BB;
32663265}
32673266
3267+ MachineBasicBlock *
3268+ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
3269+ MachineBasicBlock *BB) const {
3270+ // If the buffer is used, emit a call to __arm_sme_state_size()
3271+ MachineFunction *MF = BB->getParent();
3272+ AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
3273+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3274+ if (FuncInfo->getSMESaveBufferUsed()) {
3275+ const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
3276+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
3277+ .addExternalSymbol("__arm_sme_state_size")
3278+ .addReg(AArch64::X0, RegState::ImplicitDefine)
3279+ .addRegMask(TRI->getCallPreservedMask(
3280+ *MF, CallingConv::
3281+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
3282+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
3283+ MI.getOperand(0).getReg())
3284+ .addReg(AArch64::X0);
3285+ } else
3286+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
3287+ MI.getOperand(0).getReg())
3288+ .addReg(AArch64::XZR);
3289+ BB->remove_instr(&MI);
3290+ return BB;
3291+ }
3292+
32683293MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
32693294 MachineInstr &MI, MachineBasicBlock *BB) const {
32703295
@@ -3301,29 +3326,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
33013326 return EmitAllocateZABuffer(MI, BB);
33023327 case AArch64::AllocateSMESaveBuffer:
33033328 return EmitAllocateSMESaveBuffer(MI, BB);
3304- case AArch64::GetSMESaveSize: {
3305- // If the buffer is used, emit a call to __arm_sme_state_size()
3306- MachineFunction *MF = BB->getParent();
3307- AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
3308- const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3309- if (FuncInfo->getSMESaveBufferUsed()) {
3310- const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
3311- BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
3312- .addExternalSymbol("__arm_sme_state_size")
3313- .addReg(AArch64::X0, RegState::ImplicitDefine)
3314- .addRegMask(TRI->getCallPreservedMask(
3315- *MF, CallingConv::
3316- AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
3317- BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
3318- MI.getOperand(0).getReg())
3319- .addReg(AArch64::X0);
3320- } else
3321- BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
3322- MI.getOperand(0).getReg())
3323- .addReg(AArch64::XZR);
3324- BB->remove_instr(&MI);
3325- return BB;
3326- }
3329+ case AArch64::GetSMESaveSize:
3330+ return EmitGetSMESaveSize(MI, BB);
33273331 case AArch64::F128CSEL:
33283332 return EmitF128CSEL(MI, BB);
33293333 case TargetOpcode::STATEPOINT:
@@ -8826,6 +8830,10 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
88268830 SelectionDAG &DAG,
88278831 AArch64FunctionInfo *Info, SDLoc DL,
88288832 SDValue Chain, bool IsSave) {
8833+ MachineFunction &MF = DAG.getMachineFunction();
8834+ AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8835+ FuncInfo->setSMESaveBufferUsed();
8836+
88298837 TargetLowering::ArgListTy Args;
88308838 TargetLowering::ArgListEntry Entry;
88318839 Entry.Ty = PointerType::getUnqual(*DAG.getContext());
@@ -8841,7 +8849,6 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
88418849 CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
88428850 CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
88438851 Callee, std::move(Args));
8844-
88458852 return TLI.LowerCallTo(CLI).second;
88468853}
88478854
@@ -9007,7 +9014,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90079014 bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
90089015 bool RequiresSaveAllZA =
90099016 CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
9010- SDValue ZAStateBuffer;
90119017 if (RequiresLazySave) {
90129018 const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
90139019 MachinePointerInfo MPI =
@@ -9589,7 +9595,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
95899595 } else if (RequiresSaveAllZA) {
95909596 Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
95919597 /*IsSave=*/false);
9592- FuncInfo->setSMESaveBufferUsed();
95939598 }
95949599
95959600 if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
0 commit comments