@@ -8086,13 +8086,76 @@ static SDValue getZT0FrameIndex(MachineFrameInfo &MFI,
8086
8086
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8087
8087
}
8088
8088
8089
+ // Emit a call to __arm_sme_save or __arm_sme_restore.
8090
+ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
8091
+ SelectionDAG &DAG,
8092
+ AArch64FunctionInfo *Info, SDLoc DL,
8093
+ SDValue Chain, bool IsSave) {
8094
+ MachineFunction &MF = DAG.getMachineFunction();
8095
+ AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8096
+ FuncInfo->setSMESaveBufferUsed();
8097
+ TargetLowering::ArgListTy Args;
8098
+ Args.emplace_back(
8099
+ DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
8100
+ PointerType::getUnqual(*DAG.getContext()));
8101
+
8102
+ RTLIB::Libcall LC =
8103
+ IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
8104
+ SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
8105
+ TLI.getPointerTy(DAG.getDataLayout()));
8106
+ auto *RetTy = Type::getVoidTy(*DAG.getContext());
8107
+ TargetLowering::CallLoweringInfo CLI(DAG);
8108
+ CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
8109
+ TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
8110
+ return TLI.LowerCallTo(CLI).second;
8111
+ }
8112
+
8113
+ static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL,
8114
+ const AArch64TargetLowering &TLI,
8115
+ const AArch64RegisterInfo &TRI,
8116
+ AArch64FunctionInfo &FuncInfo,
8117
+ SelectionDAG &DAG) {
8118
+ // Conditionally restore the lazy save using a pseudo node.
8119
+ RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
8120
+ TPIDR2Object &TPIDR2 = FuncInfo.getTPIDR2Obj();
8121
+ SDValue RegMask = DAG.getRegisterMask(TRI.getCallPreservedMask(
8122
+ DAG.getMachineFunction(), TLI.getLibcallCallingConv(LC)));
8123
+ SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
8124
+ TLI.getLibcallName(LC), TLI.getPointerTy(DAG.getDataLayout()));
8125
+ SDValue TPIDR2_EL0 = DAG.getNode(
8126
+ ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Chain,
8127
+ DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
8128
+ // Copy the address of the TPIDR2 block into X0 before 'calling' the
8129
+ // RESTORE_ZA pseudo.
8130
+ SDValue Glue;
8131
+ SDValue TPIDR2Block = DAG.getFrameIndex(
8132
+ TPIDR2.FrameIndex,
8133
+ DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8134
+ Chain = DAG.getCopyToReg(Chain, DL, AArch64::X0, TPIDR2Block, Glue);
8135
+ Chain =
8136
+ DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
8137
+ {Chain, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
8138
+ RestoreRoutine, RegMask, Chain.getValue(1)});
8139
+ // Finally reset the TPIDR2_EL0 register to 0.
8140
+ Chain = DAG.getNode(
8141
+ ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8142
+ DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
8143
+ DAG.getConstant(0, DL, MVT::i64));
8144
+ TPIDR2.Uses++;
8145
+ return Chain;
8146
+ }
8147
+
8089
8148
SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
8090
8149
SelectionDAG &DAG) const {
8091
8150
assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
8092
8151
SDValue Glue = Chain.getValue(1);
8093
8152
8094
8153
MachineFunction &MF = DAG.getMachineFunction();
8095
- SMEAttrs SMEFnAttrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
8154
+ auto &FuncInfo = *MF.getInfo<AArch64FunctionInfo>();
8155
+ auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
8156
+ const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo();
8157
+
8158
+ SMEAttrs SMEFnAttrs = FuncInfo.getSMEFnAttrs();
8096
8159
8097
8160
// The following conditions are true on entry to an exception handler:
8098
8161
// - PSTATE.SM is 0.
@@ -8107,14 +8170,43 @@ SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
8107
8170
// These mode changes are usually optimized away in catch blocks as they
8108
8171
// occur before the __cxa_begin_catch (which is a non-streaming function),
8109
8172
// but are necessary in some cases (such as for cleanups).
8173
+ //
8174
+ // Additionally, if the function has ZA or ZT0 state, we must restore it.
8110
8175
8176
+ // [COND_]SMSTART SM
8111
8177
if (SMEFnAttrs.hasStreamingInterfaceOrBody())
8112
- return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
8113
- /*Glue*/ Glue, AArch64SME::Always);
8178
+ Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
8179
+ /*Glue*/ Glue, AArch64SME::Always);
8180
+ else if (SMEFnAttrs.hasStreamingCompatibleInterface())
8181
+ Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
8182
+ AArch64SME::IfCallerIsStreaming);
8183
+
8184
+ if (getTM().useNewSMEABILowering())
8185
+ return Chain;
8114
8186
8115
- if (SMEFnAttrs.hasStreamingCompatibleInterface())
8116
- return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
8117
- AArch64SME::IfCallerIsStreaming);
8187
+ if (SMEFnAttrs.hasAgnosticZAInterface()) {
8188
+ // Restore full ZA
8189
+ Chain = emitSMEStateSaveRestore(*this, DAG, &FuncInfo, DL, Chain,
8190
+ /*IsSave=*/false);
8191
+ } else if (SMEFnAttrs.hasZAState() || SMEFnAttrs.hasZT0State()) {
8192
+ // SMSTART ZA
8193
+ Chain = DAG.getNode(
8194
+ AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
8195
+ DAG.getTargetConstant(int32_t(AArch64SVCR::SVCRZA), DL, MVT::i32));
8196
+
8197
+ // Restore ZT0
8198
+ if (SMEFnAttrs.hasZT0State()) {
8199
+ SDValue ZT0FrameIndex =
8200
+ getZT0FrameIndex(MF.getFrameInfo(), FuncInfo, DAG);
8201
+ Chain =
8202
+ DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
8203
+ {Chain, DAG.getConstant(0, DL, MVT::i32), ZT0FrameIndex});
8204
+ }
8205
+
8206
+ // Restore ZA
8207
+ if (SMEFnAttrs.hasZAState())
8208
+ Chain = emitRestoreZALazySave(Chain, DL, *this, TRI, FuncInfo, DAG);
8209
+ }
8118
8210
8119
8211
return Chain;
8120
8212
}
@@ -9232,30 +9324,6 @@ SDValue AArch64TargetLowering::changeStreamingMode(
9232
9324
return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1));
9233
9325
}
9234
9326
9235
- // Emit a call to __arm_sme_save or __arm_sme_restore.
9236
- static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
9237
- SelectionDAG &DAG,
9238
- AArch64FunctionInfo *Info, SDLoc DL,
9239
- SDValue Chain, bool IsSave) {
9240
- MachineFunction &MF = DAG.getMachineFunction();
9241
- AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
9242
- FuncInfo->setSMESaveBufferUsed();
9243
- TargetLowering::ArgListTy Args;
9244
- Args.emplace_back(
9245
- DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
9246
- PointerType::getUnqual(*DAG.getContext()));
9247
-
9248
- RTLIB::Libcall LC =
9249
- IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
9250
- SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
9251
- TLI.getPointerTy(DAG.getDataLayout()));
9252
- auto *RetTy = Type::getVoidTy(*DAG.getContext());
9253
- TargetLowering::CallLoweringInfo CLI(DAG);
9254
- CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
9255
- TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
9256
- return TLI.LowerCallTo(CLI).second;
9257
- }
9258
-
9259
9327
static AArch64SME::ToggleCondition
9260
9328
getSMToggleCondition(const SMECallAttrs &CallAttrs) {
9261
9329
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
@@ -10015,33 +10083,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
10015
10083
{Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
10016
10084
10017
10085
if (RequiresLazySave) {
10018
- // Conditionally restore the lazy save using a pseudo node.
10019
- RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
10020
- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
10021
- SDValue RegMask = DAG.getRegisterMask(
10022
- TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC)));
10023
- SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
10024
- getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
10025
- SDValue TPIDR2_EL0 = DAG.getNode(
10026
- ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
10027
- DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
10028
- // Copy the address of the TPIDR2 block into X0 before 'calling' the
10029
- // RESTORE_ZA pseudo.
10030
- SDValue Glue;
10031
- SDValue TPIDR2Block = DAG.getFrameIndex(
10032
- TPIDR2.FrameIndex,
10033
- DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
10034
- Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
10035
- Result =
10036
- DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
10037
- {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
10038
- RestoreRoutine, RegMask, Result.getValue(1)});
10039
- // Finally reset the TPIDR2_EL0 register to 0.
10040
- Result = DAG.getNode(
10041
- ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
10042
- DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
10043
- DAG.getConstant(0, DL, MVT::i64));
10044
- TPIDR2.Uses++;
10086
+ Result = emitRestoreZALazySave(Result, DL, *this, *TRI, *FuncInfo, DAG);
10045
10087
} else if (RequiresSaveAllZA) {
10046
10088
Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result,
10047
10089
/*IsSave=*/false);
0 commit comments