Skip to content

Commit 2ee4c4b

Browse files
committed
[AArch64][SME][SDAG] Add basic support for exception handling
This patch adds basic support for exception handling to SelectionDAG for ZT0, ZA, and agnostic ZA state. This works based on the following assumptions: - Throwing an exception requires calling into the runtime * This will be a private ZA call (that commits the lazy save) - Therefore, as noted in https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#exceptions we will always enter the EH block with PSTATE.ZA=0 and TPIDR2_EL0=null, so we can emit a restore of ZA/ZT0. Note: This patch does not handle all cases yet. Currently, there is no support for committing agnostic ZA state before `invoke`s, regardless of whether the callee is also agnostic (to ensure ZA state is saved on all normal returns).
1 parent 9690a71 commit 2ee4c4b

File tree

2 files changed

+568
-68
lines changed

2 files changed

+568
-68
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 99 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -8094,13 +8094,76 @@ static SDValue getZT0FrameIndex(MachineFrameInfo &MFI,
80948094
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
80958095
}
80968096

8097+
// Emit a call to __arm_sme_save or __arm_sme_restore.
//
// Builds a libcall (RTLIB::SMEABI_SME_SAVE when IsSave is true, otherwise
// RTLIB::SMEABI_SME_RESTORE) that takes a single pointer argument: the SME
// save buffer address, read from the virtual register returned by
// Info->getSMESaveBufferAddr(). The call returns void and is chained after
// Chain. The function's result is the `.second` member of the pair returned
// by LowerCallTo (presumably the output chain of the lowered call).
//
// As a side effect, the current function's AArch64FunctionInfo is marked as
// using the SME save buffer (setSMESaveBufferUsed()).
8098+
static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
8099+
SelectionDAG &DAG,
8100+
AArch64FunctionInfo *Info, SDLoc DL,
8101+
SDValue Chain, bool IsSave) {
8102+
MachineFunction &MF = DAG.getMachineFunction();
8103+
// NOTE(review): FuncInfo is fetched from MF while Info is passed in by the
// caller; both visible callers pass the current function's info, so these
// presumably refer to the same object -- confirm.
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8104+
FuncInfo->setSMESaveBufferUsed();
8105+
TargetLowering::ArgListTy Args;
8106+
// Single argument: the save buffer address (copied out of its register as an
// i64), typed as an opaque pointer for the call.
Args.emplace_back(
8107+
DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
8108+
PointerType::getUnqual(*DAG.getContext()));
8109+
8110+
// Select the save or restore routine.
RTLIB::Libcall LC =
8111+
IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
8112+
SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
8113+
TLI.getPointerTy(DAG.getDataLayout()));
8114+
auto *RetTy = Type::getVoidTy(*DAG.getContext());
8115+
TargetLowering::CallLoweringInfo CLI(DAG);
8116+
// Use the calling convention registered for this libcall rather than the
// default one.
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
8117+
TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
8118+
return TLI.LowerCallTo(CLI).second;
8119+
}
8120+
8121+
// Emit the node sequence that conditionally restores ZA from a lazy save.
//
// The sequence built here is:
//   1. Read TPIDR2_EL0 via the aarch64_sme_get_tpidr2 intrinsic.
//   2. Copy the address of the function's TPIDR2 block (a frame index taken
//      from FuncInfo.getTPIDR2Obj()) into X0.
//   3. Emit an AArch64ISD::RESTORE_ZA pseudo that receives the TPIDR2_EL0
//      value, X0, the SMEABI_TPIDR2_RESTORE routine symbol and that routine's
//      call-preserved register mask.
//   4. Write 0 to TPIDR2_EL0 via the aarch64_sme_set_tpidr2 intrinsic.
//
// Also increments the use count on the TPIDR2 object. Returns the chain
// produced by the final TPIDR2_EL0 write.
static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL,
8122+
const AArch64TargetLowering &TLI,
8123+
const AArch64RegisterInfo &TRI,
8124+
AArch64FunctionInfo &FuncInfo,
8125+
SelectionDAG &DAG) {
8126+
// Conditionally restore the lazy save using a pseudo node.
8127+
RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
8128+
TPIDR2Object &TPIDR2 = FuncInfo.getTPIDR2Obj();
8129+
// Register mask for the restore routine's calling convention, passed as an
// operand of the pseudo below.
SDValue RegMask = DAG.getRegisterMask(TRI.getCallPreservedMask(
8130+
DAG.getMachineFunction(), TLI.getLibcallCallingConv(LC)));
8131+
SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
8132+
TLI.getLibcallName(LC), TLI.getPointerTy(DAG.getDataLayout()));
8133+
// Current value of TPIDR2_EL0, read on the incoming chain.
SDValue TPIDR2_EL0 = DAG.getNode(
8134+
ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Chain,
8135+
DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
8136+
// Copy the address of the TPIDR2 block into X0 before 'calling' the
8137+
// RESTORE_ZA pseudo.
8138+
SDValue Glue;
8139+
SDValue TPIDR2Block = DAG.getFrameIndex(
8140+
TPIDR2.FrameIndex,
8141+
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8142+
Chain = DAG.getCopyToReg(Chain, DL, AArch64::X0, TPIDR2Block, Glue);
8143+
// The last operand is value 1 of the CopyToReg node above (presumably its
// glue result, tying the X0 copy to the pseudo -- confirm).
Chain =
8144+
DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
8145+
{Chain, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
8146+
RestoreRoutine, RegMask, Chain.getValue(1)});
8147+
// Finally reset the TPIDR2_EL0 register to 0.
8148+
Chain = DAG.getNode(
8149+
ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8150+
DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
8151+
DAG.getConstant(0, DL, MVT::i64));
8152+
// Record one more use of the TPIDR2 object.
TPIDR2.Uses++;
8153+
return Chain;
8154+
}
8155+
80978156
SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
80988157
SelectionDAG &DAG) const {
80998158
assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
81008159
SDValue Glue = Chain.getValue(1);
81018160

81028161
MachineFunction &MF = DAG.getMachineFunction();
8103-
SMEAttrs SMEFnAttrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
8162+
auto &FuncInfo = *MF.getInfo<AArch64FunctionInfo>();
8163+
auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
8164+
const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo();
8165+
8166+
SMEAttrs SMEFnAttrs = FuncInfo.getSMEFnAttrs();
81048167

81058168
// The following conditions are true on entry to an exception handler:
81068169
// - PSTATE.SM is 0.
@@ -8115,14 +8178,43 @@ SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
81158178
// These mode changes are usually optimized away in catch blocks as they
81168179
// occur before the __cxa_begin_catch (which is a non-streaming function),
81178180
// but are necessary in some cases (such as for cleanups).
8181+
//
8182+
// Additionally, if the function has ZA or ZT0 state, we must restore it.
81188183

8184+
// [COND_]SMSTART SM
81198185
if (SMEFnAttrs.hasStreamingInterfaceOrBody())
8120-
return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
8121-
/*Glue*/ Glue, AArch64SME::Always);
8186+
Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
8187+
/*Glue*/ Glue, AArch64SME::Always);
8188+
else if (SMEFnAttrs.hasStreamingCompatibleInterface())
8189+
Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
8190+
AArch64SME::IfCallerIsStreaming);
8191+
8192+
if (getTM().useNewSMEABILowering())
8193+
return Chain;
81228194

8123-
if (SMEFnAttrs.hasStreamingCompatibleInterface())
8124-
return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
8125-
AArch64SME::IfCallerIsStreaming);
8195+
if (SMEFnAttrs.hasAgnosticZAInterface()) {
8196+
// Restore full ZA
8197+
Chain = emitSMEStateSaveRestore(*this, DAG, &FuncInfo, DL, Chain,
8198+
/*IsSave=*/false);
8199+
} else if (SMEFnAttrs.hasZAState() || SMEFnAttrs.hasZT0State()) {
8200+
// SMSTART ZA
8201+
Chain = DAG.getNode(
8202+
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
8203+
DAG.getTargetConstant(int32_t(AArch64SVCR::SVCRZA), DL, MVT::i32));
8204+
8205+
// Restore ZT0
8206+
if (SMEFnAttrs.hasZT0State()) {
8207+
SDValue ZT0FrameIndex =
8208+
getZT0FrameIndex(MF.getFrameInfo(), FuncInfo, DAG);
8209+
Chain =
8210+
DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
8211+
{Chain, DAG.getConstant(0, DL, MVT::i32), ZT0FrameIndex});
8212+
}
8213+
8214+
// Restore ZA
8215+
if (SMEFnAttrs.hasZAState())
8216+
Chain = emitRestoreZALazySave(Chain, DL, *this, TRI, FuncInfo, DAG);
8217+
}
81268218

81278219
return Chain;
81288220
}
@@ -9240,30 +9332,6 @@ SDValue AArch64TargetLowering::changeStreamingMode(
92409332
return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1));
92419333
}
92429334

9243-
// Emit a call to __arm_sme_save or __arm_sme_restore.
9244-
static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
9245-
SelectionDAG &DAG,
9246-
AArch64FunctionInfo *Info, SDLoc DL,
9247-
SDValue Chain, bool IsSave) {
9248-
MachineFunction &MF = DAG.getMachineFunction();
9249-
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
9250-
FuncInfo->setSMESaveBufferUsed();
9251-
TargetLowering::ArgListTy Args;
9252-
Args.emplace_back(
9253-
DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
9254-
PointerType::getUnqual(*DAG.getContext()));
9255-
9256-
RTLIB::Libcall LC =
9257-
IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
9258-
SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
9259-
TLI.getPointerTy(DAG.getDataLayout()));
9260-
auto *RetTy = Type::getVoidTy(*DAG.getContext());
9261-
TargetLowering::CallLoweringInfo CLI(DAG);
9262-
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
9263-
TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
9264-
return TLI.LowerCallTo(CLI).second;
9265-
}
9266-
92679335
static AArch64SME::ToggleCondition
92689336
getSMToggleCondition(const SMECallAttrs &CallAttrs) {
92699337
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
@@ -10023,33 +10091,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
1002310091
{Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
1002410092

1002510093
if (RequiresLazySave) {
10026-
// Conditionally restore the lazy save using a pseudo node.
10027-
RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
10028-
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
10029-
SDValue RegMask = DAG.getRegisterMask(
10030-
TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC)));
10031-
SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
10032-
getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
10033-
SDValue TPIDR2_EL0 = DAG.getNode(
10034-
ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
10035-
DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
10036-
// Copy the address of the TPIDR2 block into X0 before 'calling' the
10037-
// RESTORE_ZA pseudo.
10038-
SDValue Glue;
10039-
SDValue TPIDR2Block = DAG.getFrameIndex(
10040-
TPIDR2.FrameIndex,
10041-
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
10042-
Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
10043-
Result =
10044-
DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
10045-
{Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
10046-
RestoreRoutine, RegMask, Result.getValue(1)});
10047-
// Finally reset the TPIDR2_EL0 register to 0.
10048-
Result = DAG.getNode(
10049-
ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
10050-
DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
10051-
DAG.getConstant(0, DL, MVT::i64));
10052-
TPIDR2.Uses++;
10094+
Result = emitRestoreZALazySave(Result, DL, *this, *TRI, *FuncInfo, DAG);
1005310095
} else if (RequiresSaveAllZA) {
1005410096
Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result,
1005510097
/*IsSave=*/false);

0 commit comments

Comments
 (0)