Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1717,6 +1717,7 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
}
case AArch64::InOutZAUsePseudo:
case AArch64::RequiresZASavePseudo:
case AArch64::RequiresZT0SavePseudo:
case AArch64::SMEStateAllocPseudo:
case AArch64::COALESCER_BARRIER_FPR16:
case AArch64::COALESCER_BARRIER_FPR32:
Expand Down
11 changes: 8 additions & 3 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9634,6 +9634,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingAllZAState())
ZAMarkerNode = AArch64ISD::REQUIRES_ZA_SAVE;
else if (CallAttrs.requiresPreservingZT0())
ZAMarkerNode = AArch64ISD::REQUIRES_ZT0_SAVE;
else if (CallAttrs.caller().hasZAState() ||
CallAttrs.caller().hasZT0State())
ZAMarkerNode = AArch64ISD::INOUT_ZA_USE;
Expand Down Expand Up @@ -9753,7 +9755,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

SDValue ZTFrameIdx;
MachineFrameInfo &MFI = MF.getFrameInfo();
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
bool ShouldPreserveZT0 =
!UseNewSMEABILowering && CallAttrs.requiresPreservingZT0();

// If the caller has ZT0 state which will not be preserved by the callee,
// spill ZT0 before the call.
Expand All @@ -9766,7 +9769,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

// If caller shares ZT0 but the callee is not shared ZA, we need to stop
// PSTATE.ZA before the call if there is no lazy-save active.
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
bool DisableZA =
!UseNewSMEABILowering && CallAttrs.requiresDisablingZABeforeCall();
assert((!DisableZA || !RequiresLazySave) &&
"Lazy-save should have PSTATE.SM=1 on entry to the function");

Expand Down Expand Up @@ -10252,7 +10256,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
getSMToggleCondition(CallAttrs));
}

if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
if (!UseNewSMEABILowering &&
(RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall()))
// Unconditionally resume ZA.
Result = DAG.getNode(
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
let hasSideEffects = 1, isMeta = 1 in {
def InOutZAUsePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
def RequiresZASavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
def RequiresZT0SavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
}

def SMEStateAllocPseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
Expand All @@ -122,6 +123,11 @@ def AArch64_requires_za_save
[SDNPHasChain, SDNPInGlue, SDNPOutGlue]>;
def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;

def AArch64_requires_zt0_save
: SDNode<"AArch64ISD::REQUIRES_ZT0_SAVE", SDTypeProfile<0, 0, []>,
[SDNPHasChain, SDNPInGlue, SDNPOutGlue]>;
def : Pat<(AArch64_requires_zt0_save), (RequiresZT0SavePseudo)>;

def AArch64_sme_state_alloc
: SDNode<"AArch64ISD::SME_STATE_ALLOC", SDTypeProfile<0, 0,[]>,
[SDNPHasChain]>;
Expand Down
187 changes: 159 additions & 28 deletions llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,30 @@ using namespace llvm;

namespace {

enum ZAState {
// Note: For agnostic ZA, we assume the function is always entered/exited in the
// "ACTIVE" state -- this _may_ not be the case (since OFF is also a
// possibility, but for the purpose of placing ZA saves/restores, that does not
// matter).
enum ZAState : uint8_t {
// Any/unknown state (not valid)
ANY = 0,

// ZA is in use and active (i.e. within the accumulator)
ACTIVE,

// ZA is active, but ZT0 has been saved.
// This handles the edge case of sharedZA && !sharesZT0.
ACTIVE_ZT0_SAVED,

// A ZA save has been set up or committed (i.e. ZA is dormant or off)
// If the function uses ZT0 it must also be saved.
LOCAL_SAVED,

// ZA has been committed to the lazy save buffer of the current function.
// If the function uses ZT0 it must also be saved.
// ZA is off when a save has been committed.
LOCAL_COMMITTED,

// The ZA/ZT0 state on entry to the function.
ENTRY,

Expand Down Expand Up @@ -164,6 +178,14 @@ class EmitContext {
return AgnosticZABufferPtr;
}

int getZT0SaveSlot(MachineFunction &MF) {
if (ZT0SaveFI)
return *ZT0SaveFI;
MachineFrameInfo &MFI = MF.getFrameInfo();
ZT0SaveFI = MFI.CreateSpillStackObject(64, Align(16));
return *ZT0SaveFI;
}

/// Returns true if the function must allocate a ZA save buffer on entry. This
/// will be the case if, at any point in the function, a ZA save was emitted.
bool needsSaveBuffer() const {
Expand All @@ -173,6 +195,7 @@ class EmitContext {
}

private:
std::optional<int> ZT0SaveFI;
std::optional<int> TPIDR2BlockFI;
Register AgnosticZABufferPtr = AArch64::NoRegister;
};
Expand All @@ -184,8 +207,10 @@ class EmitContext {
/// state would not be legal, as transitioning to it drops the content of ZA.
static bool isLegalEdgeBundleZAState(ZAState State) {
switch (State) {
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
case ZAState::LOCAL_SAVED: // ZA state is saved on the stack.
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
case ZAState::ACTIVE_ZT0_SAVED: // ZT0 is saved (ZA is active).
case ZAState::LOCAL_SAVED: // ZA state may be saved on the stack.
case ZAState::LOCAL_COMMITTED: // ZA state is saved on the stack.
return true;
default:
return false;
Expand All @@ -199,7 +224,9 @@ StringRef getZAStateString(ZAState State) {
switch (State) {
MAKE_CASE(ZAState::ANY)
MAKE_CASE(ZAState::ACTIVE)
MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED)
MAKE_CASE(ZAState::LOCAL_SAVED)
MAKE_CASE(ZAState::LOCAL_COMMITTED)
MAKE_CASE(ZAState::ENTRY)
MAKE_CASE(ZAState::OFF)
default:
Expand All @@ -221,18 +248,39 @@ static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
/// Returns the required ZA state needed before \p MI and an iterator pointing
/// to where any code required to change the ZA state should be inserted.
static std::pair<ZAState, MachineBasicBlock::iterator>
getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
bool ZAOffAtReturn) {
getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
SMEAttrs SMEFnAttrs) {
MachineBasicBlock::iterator InsertPt(MI);

// Note: InOutZAUsePseudo, RequiresZASavePseudo, and RequiresZT0SavePseudo are
// intended to mark the position immediately before a call. Due to
// SelectionDAG constraints, these markers occur after the ADJCALLSTACKDOWN,
// so we use std::prev(InsertPt) to get the position before the call.

if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
return {ZAState::ACTIVE, std::prev(InsertPt)};

// Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};

if (MI.isReturn())
// If we only need to save ZT0 there's two cases to consider:
// 1. The function has ZA state (that we don't need to save).
// - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
// This only saves ZT0.
// 2. The function does not have ZA state
// - In this case we switch to "LOCAL_COMMITTED" state.
// This saves ZT0 and turns ZA off.
if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) {
return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED
: ZAState::LOCAL_COMMITTED,
std::prev(InsertPt)};
}

if (MI.isReturn()) {
bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface();
return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
}

for (auto &MO : MI.operands()) {
if (isZAorZTRegOp(TRI, MO))
Expand Down Expand Up @@ -280,6 +328,9 @@ struct MachineSMEABI : public MachineFunctionPass {
/// predecessors).
void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);

void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI, bool IsSave);

// Emission routines for private and shared ZA functions (using lazy saves).
void emitSMEPrologue(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
Expand All @@ -290,8 +341,8 @@ struct MachineSMEABI : public MachineFunctionPass {
MachineBasicBlock::iterator MBBI);
void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
bool ClearTPIDR2);
void emitZAMode(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
bool ClearTPIDR2, bool On);

// Emission routines for agnostic ZA functions.
void emitSetupFullZASave(MachineBasicBlock &MBB,
Expand Down Expand Up @@ -409,7 +460,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
Block.FixedEntryState = ZAState::ENTRY;
} else if (MBB.isEHPad()) {
// EH entry block:
Block.FixedEntryState = ZAState::LOCAL_SAVED;
Block.FixedEntryState = ZAState::LOCAL_COMMITTED;
}

LiveRegUnits LiveUnits(*TRI);
Expand All @@ -431,8 +482,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
}
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
auto [NeededState, InsertPt] = getZAStateBeforeInst(
*TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
auto [NeededState, InsertPt] = getInstNeededZAState(*TRI, MI, SMEFnAttrs);
assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) &&
"Unexpected state change insertion point!");
// TODO: Do something to avoid state changes where NZCV is live.
Expand Down Expand Up @@ -752,9 +802,9 @@ void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
restorePhyRegSave(RegSave, MBB, MBBI, DL);
}

void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
bool ClearTPIDR2) {
void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
bool ClearTPIDR2, bool On) {
DebugLoc DL = getDebugLoc(MBB, MBBI);

if (ClearTPIDR2)
Expand All @@ -765,7 +815,7 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
// Disable ZA.
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
.addImm(AArch64SVCR::SVCRZA)
.addImm(0);
.addImm(On ? 1 : 0);
}

void MachineSMEABI::emitAllocateLazySaveBuffer(
Expand Down Expand Up @@ -891,6 +941,28 @@ void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
restorePhyRegSave(RegSave, MBB, MBBI, DL);
}

void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context,
MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
bool IsSave) {
DebugLoc DL = getDebugLoc(MBB, MBBI);
Register ZT0Save = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);

BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), ZT0Save)
.addFrameIndex(Context.getZT0SaveSlot(*MF))
.addImm(0)
.addImm(0);

if (IsSave) {
BuildMI(MBB, MBBI, DL, TII->get(AArch64::STR_TX))
.addReg(AArch64::ZT0)
.addReg(ZT0Save);
} else {
BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDR_TX), AArch64::ZT0)
.addReg(ZT0Save);
}
}

void MachineSMEABI::emitAllocateFullZASaveBuffer(
EmitContext &Context, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
Expand Down Expand Up @@ -935,6 +1007,17 @@ void MachineSMEABI::emitAllocateFullZASaveBuffer(
restorePhyRegSave(RegSave, MBB, MBBI, DL);
}

struct FromState {
ZAState From;

constexpr uint8_t to(ZAState To) const {
static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits");
return uint8_t(From) << 4 | uint8_t(To);
}
};

constexpr FromState transitionFrom(ZAState From) { return FromState{From}; }

void MachineSMEABI::emitStateChange(EmitContext &Context,
MachineBasicBlock &MBB,
MachineBasicBlock::iterator InsertPt,
Expand All @@ -949,8 +1032,6 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
if (From == ZAState::ENTRY && To == ZAState::OFF)
return;

[[maybe_unused]] SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();

// TODO: Avoid setting up the save buffer if there's no transition to
// LOCAL_SAVED.
if (From == ZAState::ENTRY) {
Expand All @@ -966,17 +1047,67 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
From = ZAState::ACTIVE;
}

if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
else if (To == ZAState::OFF) {
assert(From != ZAState::ENTRY &&
"ENTRY to OFF should have already been handled");
assert(!SMEFnAttrs.hasAgnosticZAInterface() &&
"Should not turn ZA off in agnostic ZA function");
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
} else {
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
bool HasZT0State = SMEFnAttrs.hasZT0State();
bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState();

switch (transitionFrom(From).to(To)) {
// This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
case transitionFrom(ZAState::ACTIVE).to(ZAState::ACTIVE_ZT0_SAVED):
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
break;
case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::ACTIVE):
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
break;

// This section handles: ACTIVE[_ZT0_SAVED] -> LOCAL_SAVED
case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_SAVED):
case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::LOCAL_SAVED):
if (HasZT0State && From == ZAState::ACTIVE)
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
if (HasZAState)
emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
break;

// This section handles: ACTIVE -> LOCAL_COMMITTED
case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_COMMITTED):
// TODO: We could support ZA state here, but this transition is currently
// only possible when we _don't_ have ZA state.
assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state.");
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/false);
break;

// This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::OFF):
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::LOCAL_SAVED):
// These transistions are a no-op.
break;

// This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE):
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE_ZT0_SAVED):
case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE):
if (HasZAState)
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
else
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/true);
if (HasZT0State && To == ZAState::ACTIVE)
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
break;

// This section handles transistions to OFF (not previously covered)
case transitionFrom(ZAState::ACTIVE).to(ZAState::OFF):
case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::OFF):
case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::OFF):
assert(SMEFnAttrs.hasPrivateZAInterface() &&
"Did not expect to turn ZA off in shared/agnostic ZA function");
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED,
/*On=*/false);
break;

default:
dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
<< getZAStateString(To) << '\n';
llvm_unreachable("Unimplemented state transition");
Expand Down
4 changes: 0 additions & 4 deletions llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,6 @@ define void @test7() nounwind "aarch64_inout_zt0" {
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
Expand Down
Loading