Skip to content

Commit 271688b

Browse files
authored
[AArch64][SME] Port all SME routines to RuntimeLibcalls (#152505)
This updates everywhere we emit/check an SME routines to use RuntimeLibcalls to get the function name and calling convention. Note: RuntimeLibcallEmitter had some issues with emitting non-unique variable names for sets of libcalls, so I tweaked the output to avoid the need for variables.
1 parent b9138bd commit 271688b

14 files changed

+217
-185
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3560,6 +3560,12 @@ class LLVM_ABI TargetLoweringBase {
35603560
return Libcalls.getLibcallImplName(Call);
35613561
}
35623562

3563+
/// Check if this is valid libcall for the current module, otherwise
3564+
/// RTLIB::Unsupported.
3565+
RTLIB::LibcallImpl getSupportedLibcallImpl(StringRef FuncName) const {
3566+
return Libcalls.getSupportedLibcallImpl(FuncName);
3567+
}
3568+
35633569
const char *getMemcpyName() const { return Libcalls.getMemcpyName(); }
35643570

35653571
/// Get the comparison predicate that's to be used to test the result of the

llvm/include/llvm/IR/RuntimeLibcalls.td

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,17 @@ multiclass LibmLongDoubleLibCall<string libcall_basename = !toupper(NAME),
406406
def SC_MEMCPY : RuntimeLibcall;
407407
def SC_MEMMOVE : RuntimeLibcall;
408408
def SC_MEMSET : RuntimeLibcall;
409+
def SC_MEMCHR: RuntimeLibcall;
410+
411+
// AArch64 SME ABI calls
412+
def SMEABI_SME_STATE : RuntimeLibcall;
413+
def SMEABI_TPIDR2_SAVE : RuntimeLibcall;
414+
def SMEABI_ZA_DISABLE : RuntimeLibcall;
415+
def SMEABI_TPIDR2_RESTORE : RuntimeLibcall;
416+
def SMEABI_GET_CURRENT_VG : RuntimeLibcall;
417+
def SMEABI_SME_STATE_SIZE : RuntimeLibcall;
418+
def SMEABI_SME_SAVE : RuntimeLibcall;
419+
def SMEABI_SME_RESTORE : RuntimeLibcall;
409420

410421
// ARM EABI calls
411422
def AEABI_MEMCPY4 : RuntimeLibcall; // Align 4
@@ -1223,8 +1234,35 @@ defset list<RuntimeLibcallImpl> AArch64LibcallImpls = {
12231234
def __arm_sc_memcpy : RuntimeLibcallImpl<SC_MEMCPY>;
12241235
def __arm_sc_memmove : RuntimeLibcallImpl<SC_MEMMOVE>;
12251236
def __arm_sc_memset : RuntimeLibcallImpl<SC_MEMSET>;
1237+
def __arm_sc_memchr : RuntimeLibcallImpl<SC_MEMCHR>;
12261238
} // End AArch64LibcallImpls
12271239

1240+
def __arm_sme_state : RuntimeLibcallImpl<SMEABI_SME_STATE>;
1241+
def __arm_tpidr2_save : RuntimeLibcallImpl<SMEABI_TPIDR2_SAVE>;
1242+
def __arm_za_disable : RuntimeLibcallImpl<SMEABI_ZA_DISABLE>;
1243+
def __arm_tpidr2_restore : RuntimeLibcallImpl<SMEABI_TPIDR2_RESTORE>;
1244+
def __arm_get_current_vg : RuntimeLibcallImpl<SMEABI_GET_CURRENT_VG>;
1245+
def __arm_sme_state_size : RuntimeLibcallImpl<SMEABI_SME_STATE_SIZE>;
1246+
def __arm_sme_save : RuntimeLibcallImpl<SMEABI_SME_SAVE>;
1247+
def __arm_sme_restore : RuntimeLibcallImpl<SMEABI_SME_RESTORE>;
1248+
1249+
def SMEABI_LibCalls_PreserveMost_From_X0 : LibcallsWithCC<(add
1250+
__arm_tpidr2_save,
1251+
__arm_za_disable,
1252+
__arm_tpidr2_restore),
1253+
SMEABI_PreserveMost_From_X0>;
1254+
1255+
def SMEABI_LibCalls_PreserveMost_From_X1 : LibcallsWithCC<(add
1256+
__arm_get_current_vg,
1257+
__arm_sme_state_size,
1258+
__arm_sme_save,
1259+
__arm_sme_restore),
1260+
SMEABI_PreserveMost_From_X1>;
1261+
1262+
def SMEABI_LibCalls_PreserveMost_From_X2 : LibcallsWithCC<(add
1263+
__arm_sme_state),
1264+
SMEABI_PreserveMost_From_X2>;
1265+
12281266
def isAArch64_ExceptArm64EC
12291267
: RuntimeLibcallPredicate<"(TT.isAArch64() && !TT.isWindowsArm64EC())">;
12301268
def isWindowsArm64EC : RuntimeLibcallPredicate<"TT.isWindowsArm64EC()">;
@@ -1244,7 +1282,10 @@ def AArch64SystemLibrary : SystemRuntimeLibrary<
12441282
LibmHasSinCosF32, LibmHasSinCosF64, LibmHasSinCosF128,
12451283
DefaultLibmExp10,
12461284
DefaultStackProtector,
1247-
SecurityCheckCookieIfWinMSVC)
1285+
SecurityCheckCookieIfWinMSVC,
1286+
SMEABI_LibCalls_PreserveMost_From_X0,
1287+
SMEABI_LibCalls_PreserveMost_From_X1,
1288+
SMEABI_LibCalls_PreserveMost_From_X2)
12481289
>;
12491290

12501291
// Prepend a # to every name

llvm/include/llvm/IR/RuntimeLibcallsImpl.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def ARM_AAPCS : LibcallCallingConv<[{CallingConv::ARM_AAPCS}]>;
3636
def ARM_AAPCS_VFP : LibcallCallingConv<[{CallingConv::ARM_AAPCS_VFP}]>;
3737
def X86_STDCALL : LibcallCallingConv<[{CallingConv::X86_StdCall}]>;
3838
def AVR_BUILTIN : LibcallCallingConv<[{CallingConv::AVR_BUILTIN}]>;
39+
def SMEABI_PreserveMost_From_X0 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0}]>;
40+
def SMEABI_PreserveMost_From_X1 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1}]>;
41+
def SMEABI_PreserveMost_From_X2 : LibcallCallingConv<[{CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2}]>;
3942

4043
/// Abstract definition for functionality the compiler may need to
4144
/// emit a call to. Emits the RTLIB::Libcall enum - This enum defines

llvm/lib/Target/AArch64/AArch64FrameLowering.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,8 +1487,11 @@ bool isVGInstruction(MachineBasicBlock::iterator MBBI) {
14871487

14881488
if (Opc == AArch64::BL) {
14891489
auto Op1 = MBBI->getOperand(0);
1490-
return Op1.isSymbol() &&
1491-
(StringRef(Op1.getSymbolName()) == "__arm_get_current_vg");
1490+
auto &TLI =
1491+
*MBBI->getMF()->getSubtarget<AArch64Subtarget>().getTargetLowering();
1492+
char const *GetCurrentVG =
1493+
TLI.getLibcallName(RTLIB::SMEABI_GET_CURRENT_VG);
1494+
return Op1.isSymbol() && StringRef(Op1.getSymbolName()) == GetCurrentVG;
14921495
}
14931496
}
14941497

@@ -3468,6 +3471,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
34683471
MachineBasicBlock &MBB, MachineBasicBlock::iterator MI,
34693472
ArrayRef<CalleeSavedInfo> CSI, const TargetRegisterInfo *TRI) const {
34703473
MachineFunction &MF = *MBB.getParent();
3474+
auto &TLI = *MF.getSubtarget<AArch64Subtarget>().getTargetLowering();
34713475
const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
34723476
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
34733477
bool NeedsWinCFI = needsWinCFI(MF);
@@ -3581,11 +3585,11 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
35813585
.addReg(AArch64::X0, RegState::Implicit)
35823586
.setMIFlag(MachineInstr::FrameSetup);
35833587

3584-
const uint32_t *RegMask = TRI->getCallPreservedMask(
3585-
MF,
3586-
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);
3588+
RTLIB::Libcall LC = RTLIB::SMEABI_GET_CURRENT_VG;
3589+
const uint32_t *RegMask =
3590+
TRI->getCallPreservedMask(MF, TLI.getLibcallCallingConv(LC));
35873591
BuildMI(MBB, MI, DL, TII.get(AArch64::BL))
3588-
.addExternalSymbol("__arm_get_current_vg")
3592+
.addExternalSymbol(TLI.getLibcallName(LC))
35893593
.addRegMask(RegMask)
35903594
.addReg(AArch64::X0, RegState::ImplicitDefine)
35913595
.setMIFlag(MachineInstr::FrameSetup);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3083,13 +3083,12 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
30833083
AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
30843084
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
30853085
if (FuncInfo->isSMESaveBufferUsed()) {
3086+
RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE_SIZE;
30863087
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
30873088
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
3088-
.addExternalSymbol("__arm_sme_state_size")
3089+
.addExternalSymbol(getLibcallName(LC))
30893090
.addReg(AArch64::X0, RegState::ImplicitDefine)
3090-
.addRegMask(TRI->getCallPreservedMask(
3091-
*MF, CallingConv::
3092-
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
3091+
.addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC)));
30933092
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
30943093
MI.getOperand(0).getReg())
30953094
.addReg(AArch64::X0);
@@ -3109,13 +3108,12 @@ AArch64TargetLowering::EmitEntryPStateSM(MachineInstr &MI,
31093108
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
31103109
Register ResultReg = MI.getOperand(0).getReg();
31113110
if (FuncInfo->isPStateSMRegUsed()) {
3111+
RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE;
31123112
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
31133113
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
3114-
.addExternalSymbol("__arm_sme_state")
3114+
.addExternalSymbol(getLibcallName(LC))
31153115
.addReg(AArch64::X0, RegState::ImplicitDefine)
3116-
.addRegMask(TRI->getCallPreservedMask(
3117-
*MF, CallingConv::
3118-
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2));
3116+
.addRegMask(TRI->getCallPreservedMask(*MF, getLibcallCallingConv(LC)));
31193117
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), ResultReg)
31203118
.addReg(AArch64::X0);
31213119
} else {
@@ -5739,15 +5737,15 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
57395737
SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
57405738
SDValue Chain, SDLoc DL,
57415739
EVT VT) const {
5742-
SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
5740+
RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE;
5741+
SDValue Callee = DAG.getExternalSymbol(getLibcallName(LC),
57435742
getPointerTy(DAG.getDataLayout()));
57445743
Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
57455744
Type *RetTy = StructType::get(Int64Ty, Int64Ty);
57465745
TargetLowering::CallLoweringInfo CLI(DAG);
57475746
ArgListTy Args;
57485747
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
5749-
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2,
5750-
RetTy, Callee, std::move(Args));
5748+
getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
57515749
std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
57525750
SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64);
57535751
return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0),
@@ -8600,12 +8598,12 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
86008598
}
86018599

86028600
static SMECallAttrs
8603-
getSMECallAttrs(const Function &Caller,
8601+
getSMECallAttrs(const Function &Caller, const TargetLowering &TLI,
86048602
const TargetLowering::CallLoweringInfo &CLI) {
86058603
if (CLI.CB)
8606-
return SMECallAttrs(*CLI.CB);
8604+
return SMECallAttrs(*CLI.CB, &TLI);
86078605
if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8608-
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol()));
8606+
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI));
86098607
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal));
86108608
}
86118609

@@ -8627,7 +8625,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86278625

86288626
// SME Streaming functions are not eligible for TCO as they may require
86298627
// the streaming mode or ZA to be restored after returning from the call.
8630-
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8628+
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI);
86318629
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
86328630
CallAttrs.requiresPreservingAllZAState() ||
86338631
CallAttrs.caller().hasStreamingBody())
@@ -8921,14 +8919,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89218919
DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
89228920
Args.push_back(Entry);
89238921

8924-
SDValue Callee =
8925-
DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
8926-
TLI.getPointerTy(DAG.getDataLayout()));
8922+
RTLIB::Libcall LC =
8923+
IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
8924+
SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
8925+
TLI.getPointerTy(DAG.getDataLayout()));
89278926
auto *RetTy = Type::getVoidTy(*DAG.getContext());
89288927
TargetLowering::CallLoweringInfo CLI(DAG);
89298928
CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
8930-
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
8931-
Callee, std::move(Args));
8929+
TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
89328930
return TLI.LowerCallTo(CLI).second;
89338931
}
89348932

@@ -9116,7 +9114,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91169114
}
91179115

91189116
// Determine whether we need any streaming mode changes.
9119-
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
9117+
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
91209118

91219119
auto DescribeCallsite =
91229120
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9693,11 +9691,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96939691

96949692
if (RequiresLazySave) {
96959693
// Conditionally restore the lazy save using a pseudo node.
9694+
RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
96969695
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
96979696
SDValue RegMask = DAG.getRegisterMask(
9698-
TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
9697+
TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC)));
96999698
SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
9700-
"__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
9699+
getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
97019700
SDValue TPIDR2_EL0 = DAG.getNode(
97029701
ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
97039702
DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
@@ -29036,7 +29035,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2903629035

2903729036
// Checks to allow the use of SME instructions
2903829037
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
29039-
auto CallAttrs = SMECallAttrs(*Base);
29038+
auto CallAttrs = SMECallAttrs(*Base, this);
2904029039
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
2904129040
CallAttrs.requiresPreservingZT0() ||
2904229041
CallAttrs.requiresPreservingAllZAState())

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -220,20 +220,16 @@ static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
220220
static cl::opt<bool> EnableScalableAutovecInStreamingMode(
221221
"enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
222222

223-
static bool isSMEABIRoutineCall(const CallInst &CI) {
223+
static bool isSMEABIRoutineCall(const CallInst &CI, const TargetLowering &TLI) {
224224
const auto *F = CI.getCalledFunction();
225-
return F && StringSwitch<bool>(F->getName())
226-
.Case("__arm_sme_state", true)
227-
.Case("__arm_tpidr2_save", true)
228-
.Case("__arm_tpidr2_restore", true)
229-
.Case("__arm_za_disable", true)
230-
.Default(false);
225+
return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine();
231226
}
232227

233228
/// Returns true if the function has explicit operations that can only be
234229
/// lowered using incompatible instructions for the selected mode. This also
235230
/// returns true if the function F may use or modify ZA state.
236-
static bool hasPossibleIncompatibleOps(const Function *F) {
231+
static bool hasPossibleIncompatibleOps(const Function *F,
232+
const TargetLowering &TLI) {
237233
for (const BasicBlock &BB : *F) {
238234
for (const Instruction &I : BB) {
239235
// Be conservative for now and assume that any call to inline asm or to
@@ -242,7 +238,7 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
242238
// all native LLVM instructions can be lowered to compatible instructions.
243239
if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
244240
(cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
245-
isSMEABIRoutineCall(cast<CallInst>(I))))
241+
isSMEABIRoutineCall(cast<CallInst>(I), TLI)))
246242
return true;
247243
}
248244
}
@@ -290,7 +286,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
290286
if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
291287
CallAttrs.requiresPreservingZT0() ||
292288
CallAttrs.requiresPreservingAllZAState()) {
293-
if (hasPossibleIncompatibleOps(Callee))
289+
if (hasPossibleIncompatibleOps(Callee, *getTLI()))
294290
return false;
295291
}
296292

@@ -357,7 +353,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
357353
// change only once and avoid inlining of G into F.
358354

359355
SMEAttrs FAttrs(*F);
360-
SMECallAttrs CallAttrs(Call);
356+
SMECallAttrs CallAttrs(Call, getTLI());
361357

362358
if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
363359
if (F == Call.getCaller()) // (1)

0 commit comments

Comments
 (0)