Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -3576,6 +3576,10 @@ class LLVM_ABI TargetLoweringBase {
return nullptr;
}

const RTLIB::RuntimeLibcallsInfo &getRuntimeLibcallsInfo() const {
return Libcalls;
}

void setLibcallImpl(RTLIB::Libcall Call, RTLIB::LibcallImpl Impl) {
Libcalls.setLibcallImpl(Call, Impl);
}
Expand Down
14 changes: 8 additions & 6 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9002,12 +9002,12 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
}

static SMECallAttrs
getSMECallAttrs(const Function &Caller, const AArch64TargetLowering &TLI,
getSMECallAttrs(const Function &Caller, const RTLIB::RuntimeLibcallsInfo &RTLCI,
const TargetLowering::CallLoweringInfo &CLI) {
if (CLI.CB)
return SMECallAttrs(*CLI.CB, &TLI);
return SMECallAttrs(*CLI.CB, &RTLCI);
if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI));
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), RTLCI));
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal));
}

Expand All @@ -9029,7 +9029,8 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(

// SME Streaming functions are not eligible for TCO as they may require
// the streaming mode or ZA to be restored after returning from the call.
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI);
SMECallAttrs CallAttrs =
getSMECallAttrs(CallerF, getRuntimeLibcallsInfo(), CLI);
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingAllZAState() ||
CallAttrs.caller().hasStreamingBody())
Expand Down Expand Up @@ -9454,7 +9455,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
}

// Determine whether we need any streaming mode changes.
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
SMECallAttrs CallAttrs =
getSMECallAttrs(MF.getFunction(), getRuntimeLibcallsInfo(), CLI);

std::optional<unsigned> ZAMarkerNode;
bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
Expand Down Expand Up @@ -29818,7 +29820,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {

// Checks to allow the use of SME instructions
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
auto CallAttrs = SMECallAttrs(*Base, this);
auto CallAttrs = SMECallAttrs(*Base, &getRuntimeLibcallsInfo());
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingZT0() ||
CallAttrs.requiresPreservingAllZAState())
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ static cl::opt<bool> EnableScalableAutovecInStreamingMode(
static bool isSMEABIRoutineCall(const CallInst &CI,
const AArch64TargetLowering &TLI) {
const auto *F = CI.getCalledFunction();
return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine();
return F &&
SMEAttrs(F->getName(), TLI.getRuntimeLibcallsInfo()).isSMEABIRoutine();
}

/// Returns true if the function has explicit operations that can only be
Expand Down Expand Up @@ -355,7 +356,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
// change only once and avoid inlining of G into F.

SMEAttrs FAttrs(*F);
SMECallAttrs CallAttrs(Call, getTLI());
SMECallAttrs CallAttrs(Call, &getTLI()->getRuntimeLibcallsInfo());

if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
if (F == Call.getCaller()) // (1)
Expand Down
9 changes: 5 additions & 4 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
}

void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName,
const AArch64TargetLowering &TLI) {
RTLIB::LibcallImpl Impl = TLI.getSupportedLibcallImpl(FuncName);
const RTLIB::RuntimeLibcallsInfo &RTLCI) {
RTLIB::LibcallImpl Impl = RTLCI.getSupportedLibcallImpl(FuncName);
if (Impl == RTLIB::Unsupported)
return;
unsigned KnownAttrs = SMEAttrs::Normal;
Expand Down Expand Up @@ -124,11 +124,12 @@ bool SMECallAttrs::requiresSMChange() const {
return true;
}

SMECallAttrs::SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI)
SMECallAttrs::SMECallAttrs(const CallBase &CB,
const RTLIB::RuntimeLibcallsInfo *RTLCI)
: CallerFn(*CB.getFunction()), CalledFn(SMEAttrs::Normal),
Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) {
if (auto *CalledFunction = CB.getCalledFunction())
CalledFn = SMEAttrs(*CalledFunction, TLI);
CalledFn = SMEAttrs(*CalledFunction, RTLCI);

// An `invoke` of an agnostic ZA function may not return normally (it may
// resume in an exception block). In this case, it acts like a private ZA
Expand Down
17 changes: 10 additions & 7 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include "llvm/IR/Function.h"

namespace llvm {
namespace RTLIB {
struct RuntimeLibcallsInfo;
}

class AArch64TargetLowering;

Expand Down Expand Up @@ -52,14 +55,14 @@ class SMEAttrs {

SMEAttrs() = default;
SMEAttrs(unsigned Mask) { set(Mask); }
SMEAttrs(const Function &F, const AArch64TargetLowering *TLI = nullptr)
SMEAttrs(const Function &F, const RTLIB::RuntimeLibcallsInfo *RTLCI = nullptr)
: SMEAttrs(F.getAttributes()) {
if (TLI)
addKnownFunctionAttrs(F.getName(), *TLI);
if (RTLCI)
addKnownFunctionAttrs(F.getName(), *RTLCI);
}
SMEAttrs(const AttributeList &L);
SMEAttrs(StringRef FuncName, const AArch64TargetLowering &TLI) {
addKnownFunctionAttrs(FuncName, TLI);
SMEAttrs(StringRef FuncName, const RTLIB::RuntimeLibcallsInfo &RTLCI) {
addKnownFunctionAttrs(FuncName, RTLCI);
};

void set(unsigned M, bool Enable = true) {
Expand Down Expand Up @@ -157,7 +160,7 @@ class SMEAttrs {

private:
void addKnownFunctionAttrs(StringRef FuncName,
const AArch64TargetLowering &TLI);
const RTLIB::RuntimeLibcallsInfo &RTLCI);
void validate() const;
};

Expand All @@ -175,7 +178,7 @@ class SMECallAttrs {
SMEAttrs Callsite = SMEAttrs::Normal)
: CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {}

SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI);
SMECallAttrs(const CallBase &CB, const RTLIB::RuntimeLibcallsInfo *RTLCI);

SMEAttrs &caller() { return CallerFn; }
SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; }
Expand Down
Loading