Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -3576,6 +3576,10 @@ class LLVM_ABI TargetLoweringBase {
return nullptr;
}

const RTLIB::RuntimeLibcallsInfo &getRuntimeLibcallsInfo() const {
return Libcalls;
}

void setLibcallImpl(RTLIB::Libcall Call, RTLIB::LibcallImpl Impl) {
Libcalls.setLibcallImpl(Call, Impl);
}
Expand Down
14 changes: 8 additions & 6 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9002,12 +9002,12 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
}

static SMECallAttrs
getSMECallAttrs(const Function &Caller, const AArch64TargetLowering &TLI,
getSMECallAttrs(const Function &Caller, const RTLIB::RuntimeLibcallsInfo &RTLCI,
const TargetLowering::CallLoweringInfo &CLI) {
if (CLI.CB)
return SMECallAttrs(*CLI.CB, &TLI);
return SMECallAttrs(*CLI.CB, &RTLCI);
if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI));
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), RTLCI));
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal));
}

Expand All @@ -9029,7 +9029,8 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(

// SME Streaming functions are not eligible for TCO as they may require
// the streaming mode or ZA to be restored after returning from the call.
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI);
SMECallAttrs CallAttrs =
getSMECallAttrs(CallerF, getRuntimeLibcallsInfo(), CLI);
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingAllZAState() ||
CallAttrs.caller().hasStreamingBody())
Expand Down Expand Up @@ -9454,7 +9455,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
}

// Determine whether we need any streaming mode changes.
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
SMECallAttrs CallAttrs =
getSMECallAttrs(MF.getFunction(), getRuntimeLibcallsInfo(), CLI);

std::optional<unsigned> ZAMarkerNode;
bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
Expand Down Expand Up @@ -29818,7 +29820,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {

// Checks to allow the use of SME instructions
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
auto CallAttrs = SMECallAttrs(*Base, this);
auto CallAttrs = SMECallAttrs(*Base, &getRuntimeLibcallsInfo());
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingZT0() ||
CallAttrs.requiresPreservingAllZAState())
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ static cl::opt<bool> EnableScalableAutovecInStreamingMode(
static bool isSMEABIRoutineCall(const CallInst &CI,
const AArch64TargetLowering &TLI) {
const auto *F = CI.getCalledFunction();
return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine();
return F &&
SMEAttrs(F->getName(), TLI.getRuntimeLibcallsInfo()).isSMEABIRoutine();
}

/// Returns true if the function has explicit operations that can only be
Expand Down Expand Up @@ -355,7 +356,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
// change only once and avoid inlining of G into F.

SMEAttrs FAttrs(*F);
SMECallAttrs CallAttrs(Call, getTLI());
SMECallAttrs CallAttrs(Call, &getTLI()->getRuntimeLibcallsInfo());

if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
if (F == Call.getCaller()) // (1)
Expand Down
9 changes: 5 additions & 4 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
}

void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName,
const AArch64TargetLowering &TLI) {
RTLIB::LibcallImpl Impl = TLI.getSupportedLibcallImpl(FuncName);
const RTLIB::RuntimeLibcallsInfo &RTLCI) {
RTLIB::LibcallImpl Impl = RTLCI.getSupportedLibcallImpl(FuncName);
if (Impl == RTLIB::Unsupported)
return;
unsigned KnownAttrs = SMEAttrs::Normal;
Expand Down Expand Up @@ -124,11 +124,12 @@ bool SMECallAttrs::requiresSMChange() const {
return true;
}

SMECallAttrs::SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI)
SMECallAttrs::SMECallAttrs(const CallBase &CB,
const RTLIB::RuntimeLibcallsInfo *RTLCI)
: CallerFn(*CB.getFunction()), CalledFn(SMEAttrs::Normal),
Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) {
if (auto *CalledFunction = CB.getCalledFunction())
CalledFn = SMEAttrs(*CalledFunction, TLI);
CalledFn = SMEAttrs(*CalledFunction, RTLCI);

// An `invoke` of an agnostic ZA function may not return normally (it may
// resume in an exception block). In this case, it acts like a private ZA
Expand Down
19 changes: 10 additions & 9 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
#include "llvm/IR/Function.h"

namespace llvm {

class AArch64TargetLowering;
namespace RTLIB {
struct RuntimeLibcallsInfo;
}

class Function;
class CallBase;
Expand Down Expand Up @@ -52,14 +53,14 @@ class SMEAttrs {

SMEAttrs() = default;
SMEAttrs(unsigned Mask) { set(Mask); }
SMEAttrs(const Function &F, const AArch64TargetLowering *TLI = nullptr)
SMEAttrs(const Function &F, const RTLIB::RuntimeLibcallsInfo *RTLCI = nullptr)
: SMEAttrs(F.getAttributes()) {
if (TLI)
addKnownFunctionAttrs(F.getName(), *TLI);
if (RTLCI)
addKnownFunctionAttrs(F.getName(), *RTLCI);
}
SMEAttrs(const AttributeList &L);
SMEAttrs(StringRef FuncName, const AArch64TargetLowering &TLI) {
addKnownFunctionAttrs(FuncName, TLI);
SMEAttrs(StringRef FuncName, const RTLIB::RuntimeLibcallsInfo &RTLCI) {
addKnownFunctionAttrs(FuncName, RTLCI);
};

void set(unsigned M, bool Enable = true) {
Expand Down Expand Up @@ -157,7 +158,7 @@ class SMEAttrs {

private:
void addKnownFunctionAttrs(StringRef FuncName,
const AArch64TargetLowering &TLI);
const RTLIB::RuntimeLibcallsInfo &RTLCI);
void validate() const;
};

Expand All @@ -175,7 +176,7 @@ class SMECallAttrs {
SMEAttrs Callsite = SMEAttrs::Normal)
: CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {}

SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI);
SMECallAttrs(const CallBase &CB, const RTLIB::RuntimeLibcallsInfo *RTLCI);

SMEAttrs &caller() { return CallerFn; }
SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; }
Expand Down
Loading