@@ -8094,13 +8094,76 @@ static SDValue getZT0FrameIndex(MachineFrameInfo &MFI,
80948094 DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
80958095}
80968096
8097+ // Emit a call to __arm_sme_save or __arm_sme_restore.
8098+ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
8099+ SelectionDAG &DAG,
8100+ AArch64FunctionInfo *Info, SDLoc DL,
8101+ SDValue Chain, bool IsSave) {
8102+ MachineFunction &MF = DAG.getMachineFunction();
8103+ AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8104+ FuncInfo->setSMESaveBufferUsed();
8105+ TargetLowering::ArgListTy Args;
8106+ Args.emplace_back(
8107+ DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
8108+ PointerType::getUnqual(*DAG.getContext()));
8109+
8110+ RTLIB::Libcall LC =
8111+ IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
8112+ SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
8113+ TLI.getPointerTy(DAG.getDataLayout()));
8114+ auto *RetTy = Type::getVoidTy(*DAG.getContext());
8115+ TargetLowering::CallLoweringInfo CLI(DAG);
8116+ CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
8117+ TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
8118+ return TLI.LowerCallTo(CLI).second;
8119+ }
8120+
8121+ static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL,
8122+ const AArch64TargetLowering &TLI,
8123+ const AArch64RegisterInfo &TRI,
8124+ AArch64FunctionInfo &FuncInfo,
8125+ SelectionDAG &DAG) {
8126+ // Conditionally restore the lazy save using a pseudo node.
8127+ RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
8128+ TPIDR2Object &TPIDR2 = FuncInfo.getTPIDR2Obj();
8129+ SDValue RegMask = DAG.getRegisterMask(TRI.getCallPreservedMask(
8130+ DAG.getMachineFunction(), TLI.getLibcallCallingConv(LC)));
8131+ SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
8132+ TLI.getLibcallName(LC), TLI.getPointerTy(DAG.getDataLayout()));
8133+ SDValue TPIDR2_EL0 = DAG.getNode(
8134+ ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Chain,
8135+ DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
8136+ // Copy the address of the TPIDR2 block into X0 before 'calling' the
8137+ // RESTORE_ZA pseudo.
8138+ SDValue Glue;
8139+ SDValue TPIDR2Block = DAG.getFrameIndex(
8140+ TPIDR2.FrameIndex,
8141+ DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8142+ Chain = DAG.getCopyToReg(Chain, DL, AArch64::X0, TPIDR2Block, Glue);
8143+ Chain =
8144+ DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
8145+ {Chain, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
8146+ RestoreRoutine, RegMask, Chain.getValue(1)});
8147+ // Finally reset the TPIDR2_EL0 register to 0.
8148+ Chain = DAG.getNode(
8149+ ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8150+ DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
8151+ DAG.getConstant(0, DL, MVT::i64));
8152+ TPIDR2.Uses++;
8153+ return Chain;
8154+ }
8155+
80978156SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
80988157 SelectionDAG &DAG) const {
80998158 assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
81008159 SDValue Glue = Chain.getValue(1);
81018160
81028161 MachineFunction &MF = DAG.getMachineFunction();
8103- SMEAttrs SMEFnAttrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
8162+ auto &FuncInfo = *MF.getInfo<AArch64FunctionInfo>();
8163+ auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
8164+ const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo();
8165+
8166+ SMEAttrs SMEFnAttrs = FuncInfo.getSMEFnAttrs();
81048167
81058168 // The following conditions are true on entry to an exception handler:
81068169 // - PSTATE.SM is 0.
@@ -8115,14 +8178,43 @@ SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
81158178 // These mode changes are usually optimized away in catch blocks as they
81168179 // occur before the __cxa_begin_catch (which is a non-streaming function),
81178180 // but are necessary in some cases (such as for cleanups).
8181+ //
8182+ // Additionally, if the function has ZA or ZT0 state, we must restore it.
81188183
8184+ // [COND_]SMSTART SM
81198185 if (SMEFnAttrs.hasStreamingInterfaceOrBody())
8120- return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
8121- /*Glue*/ Glue, AArch64SME::Always);
8186+ Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
8187+ /*Glue*/ Glue, AArch64SME::Always);
8188+ else if (SMEFnAttrs.hasStreamingCompatibleInterface())
8189+ Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
8190+ AArch64SME::IfCallerIsStreaming);
8191+
8192+ if (getTM().useNewSMEABILowering())
8193+ return Chain;
81228194
8123- if (SMEFnAttrs.hasStreamingCompatibleInterface())
8124- return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
8125- AArch64SME::IfCallerIsStreaming);
8195+ if (SMEFnAttrs.hasAgnosticZAInterface()) {
8196+ // Restore full ZA
8197+ Chain = emitSMEStateSaveRestore(*this, DAG, &FuncInfo, DL, Chain,
8198+ /*IsSave=*/false);
8199+ } else if (SMEFnAttrs.hasZAState() || SMEFnAttrs.hasZT0State()) {
8200+ // SMSTART ZA
8201+ Chain = DAG.getNode(
8202+ AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
8203+ DAG.getTargetConstant(int32_t(AArch64SVCR::SVCRZA), DL, MVT::i32));
8204+
8205+ // Restore ZT0
8206+ if (SMEFnAttrs.hasZT0State()) {
8207+ SDValue ZT0FrameIndex =
8208+ getZT0FrameIndex(MF.getFrameInfo(), FuncInfo, DAG);
8209+ Chain =
8210+ DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
8211+ {Chain, DAG.getConstant(0, DL, MVT::i32), ZT0FrameIndex});
8212+ }
8213+
8214+ // Restore ZA
8215+ if (SMEFnAttrs.hasZAState())
8216+ Chain = emitRestoreZALazySave(Chain, DL, *this, TRI, FuncInfo, DAG);
8217+ }
81268218
81278219 return Chain;
81288220}
@@ -9240,30 +9332,6 @@ SDValue AArch64TargetLowering::changeStreamingMode(
92409332 return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1));
92419333}
92429334
9243- // Emit a call to __arm_sme_save or __arm_sme_restore.
9244- static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
9245- SelectionDAG &DAG,
9246- AArch64FunctionInfo *Info, SDLoc DL,
9247- SDValue Chain, bool IsSave) {
9248- MachineFunction &MF = DAG.getMachineFunction();
9249- AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
9250- FuncInfo->setSMESaveBufferUsed();
9251- TargetLowering::ArgListTy Args;
9252- Args.emplace_back(
9253- DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
9254- PointerType::getUnqual(*DAG.getContext()));
9255-
9256- RTLIB::Libcall LC =
9257- IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
9258- SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
9259- TLI.getPointerTy(DAG.getDataLayout()));
9260- auto *RetTy = Type::getVoidTy(*DAG.getContext());
9261- TargetLowering::CallLoweringInfo CLI(DAG);
9262- CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
9263- TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
9264- return TLI.LowerCallTo(CLI).second;
9265- }
9266-
92679335static AArch64SME::ToggleCondition
92689336getSMToggleCondition(const SMECallAttrs &CallAttrs) {
92699337 if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
@@ -10023,33 +10091,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
1002310091 {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
1002410092
1002510093 if (RequiresLazySave) {
10026- // Conditionally restore the lazy save using a pseudo node.
10027- RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
10028- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
10029- SDValue RegMask = DAG.getRegisterMask(
10030- TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC)));
10031- SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
10032- getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
10033- SDValue TPIDR2_EL0 = DAG.getNode(
10034- ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
10035- DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
10036- // Copy the address of the TPIDR2 block into X0 before 'calling' the
10037- // RESTORE_ZA pseudo.
10038- SDValue Glue;
10039- SDValue TPIDR2Block = DAG.getFrameIndex(
10040- TPIDR2.FrameIndex,
10041- DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
10042- Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
10043- Result =
10044- DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
10045- {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
10046- RestoreRoutine, RegMask, Result.getValue(1)});
10047- // Finally reset the TPIDR2_EL0 register to 0.
10048- Result = DAG.getNode(
10049- ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
10050- DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
10051- DAG.getConstant(0, DL, MVT::i64));
10052- TPIDR2.Uses++;
10094+ Result = emitRestoreZALazySave(Result, DL, *this, *TRI, *FuncInfo, DAG);
1005310095 } else if (RequiresSaveAllZA) {
1005410096 Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result,
1005510097 /*IsSave=*/false);
0 commit comments