Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass();
FunctionPass *createAArch64CollectLOHPass();
FunctionPass *createSMEABIPass();
FunctionPass *createSMEPeepholeOptPass();
FunctionPass *createMachineSMEABIPass();
FunctionPass *createMachineSMEABIPass(CodeGenOptLevel);
ModulePass *createSVEIntrinsicOptsPass();
InstructionSelector *
createAArch64InstructionSelector(const AArch64TargetMachine &,
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,8 +764,8 @@ bool AArch64PassConfig::addGlobalInstructionSelect() {
}

void AArch64PassConfig::addMachineSSAOptimization() {
if (EnableNewSMEABILowering && TM->getOptLevel() != CodeGenOptLevel::None)
addPass(createMachineSMEABIPass());
if (TM->getOptLevel() != CodeGenOptLevel::None && EnableNewSMEABILowering)
addPass(createMachineSMEABIPass(TM->getOptLevel()));

if (TM->getOptLevel() != CodeGenOptLevel::None && EnableSMEPeepholeOpt)
addPass(createSMEPeepholeOptPass());
Expand Down Expand Up @@ -798,7 +798,7 @@ bool AArch64PassConfig::addILPOpts() {

void AArch64PassConfig::addPreRegAlloc() {
if (TM->getOptLevel() == CodeGenOptLevel::None && EnableNewSMEABILowering)
addPass(createMachineSMEABIPass());
addPass(createMachineSMEABIPass(CodeGenOptLevel::None));

// Change dead register definitions to refer to the zero register.
if (TM->getOptLevel() != CodeGenOptLevel::None &&
Expand Down
150 changes: 129 additions & 21 deletions llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ struct InstInfo {
/// Contains the needed ZA state for each instruction in a block. Instructions
/// that do not require a ZA state are not recorded.
struct BlockInfo {
  /// ZA state requirements of the instructions in this block, stored in
  /// program order. Instructions without a ZA state requirement are omitted.
  SmallVector<InstInfo> Insts;

  /// NOTE(review): presumably the state this block is required to be entered
  /// in, with ANY meaning "unconstrained" -- confirm against the code that
  /// sets it (not visible here).
  ZAState FixedEntryState{ZAState::ANY};

  /// The state on entry to this block that would not incur a state
  /// transition. Seeded from the first recorded instruction and may later be
  /// overwritten by propagateDesiredStates(). ANY means no preference.
  ZAState DesiredIncomingState{ZAState::ANY};

  /// The state on exit from this block that would not incur a state
  /// transition. Seeded from the last recorded instruction and may later be
  /// overwritten by propagateDesiredStates(). ANY means no preference.
  ZAState DesiredOutgoingState{ZAState::ANY};

  /// NOTE(review): presumably the physical registers live at the start of the
  /// block -- confirm which registers LiveRegs tracks.
  LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;

  /// NOTE(review): presumably the physical registers live at the end of the
  /// block -- confirm which registers LiveRegs tracks.
  LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
};
Expand Down Expand Up @@ -175,10 +177,15 @@ class EmitContext {
Register AgnosticZABufferPtr = AArch64::NoRegister;
};

/// Checks if \p State is a legal edge bundle state. For a state to be a legal
/// bundle state, it must be possible to transition from it to any other bundle
/// state without losing any ZA state. This is the case for ACTIVE/LOCAL_SAVED,
/// as you can transition between those states by saving/restoring ZA. The OFF
/// state would not be legal, as transitioning to it drops the content of ZA.
static bool isLegalEdgeBundleZAState(ZAState State) {
switch (State) {
case ZAState::ACTIVE:
case ZAState::LOCAL_SAVED:
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
case ZAState::LOCAL_SAVED: // ZA state is saved on the stack.
return true;
default:
return false;
Expand Down Expand Up @@ -238,7 +245,8 @@ getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
struct MachineSMEABI : public MachineFunctionPass {
inline static char ID = 0;

MachineSMEABI() : MachineFunctionPass(ID) {}
MachineSMEABI(CodeGenOptLevel OptLevel = CodeGenOptLevel::Default)
: MachineFunctionPass(ID), OptLevel(OptLevel) {}

bool runOnMachineFunction(MachineFunction &MF) override;

Expand Down Expand Up @@ -267,6 +275,11 @@ struct MachineSMEABI : public MachineFunctionPass {
const EdgeBundles &Bundles,
ArrayRef<ZAState> BundleStates);

/// Propagates desired states forwards (from predecessors -> successors) if
/// \p Forwards, otherwise, propagates backwards (from successors ->
/// predecessors).
void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);

// Emission routines for private and shared ZA functions (using lazy saves).
void emitNewZAPrologue(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
Expand Down Expand Up @@ -335,12 +348,15 @@ struct MachineSMEABI : public MachineFunctionPass {
MachineBasicBlock::iterator MBBI, DebugLoc DL);

private:
CodeGenOptLevel OptLevel = CodeGenOptLevel::Default;

MachineFunction *MF = nullptr;
const AArch64Subtarget *Subtarget = nullptr;
const AArch64RegisterInfo *TRI = nullptr;
const AArch64FunctionInfo *AFI = nullptr;
const TargetInstrInfo *TII = nullptr;
MachineRegisterInfo *MRI = nullptr;
MachineLoopInfo *MLI = nullptr;
};

static LiveRegs getPhysLiveRegs(LiveRegUnits const &LiveUnits) {
Expand Down Expand Up @@ -422,12 +438,69 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {

// Reverse vector (as we had to iterate backwards for liveness).
std::reverse(Block.Insts.begin(), Block.Insts.end());

// Record the desired states on entry/exit of this block. These are the
// states that would not incur a state transition.
if (!Block.Insts.empty()) {
Block.DesiredIncomingState = Block.Insts.front().NeededState;
Block.DesiredOutgoingState = Block.Insts.back().NeededState;
}
}

return FunctionInfo{std::move(Blocks), AfterSMEProloguePt,
PhysLiveRegsAfterSMEPrologue};
}

/// Fills in desired block states that are not yet a legal edge bundle state
/// by adopting the most common legal state wanted by neighbouring blocks.
///
/// \param FnInfo   Per-block information; the DesiredIncomingState and
///                 DesiredOutgoingState fields are updated in place.
/// \param Forwards When true, a block's incoming state is derived from the
///                 outgoing states of its predecessors. When false, a block's
///                 outgoing state is derived from the incoming states of its
///                 successors.
void MachineSMEABI::propagateDesiredStates(FunctionInfo &FnInfo,
                                           bool Forwards) {
  // Returns a mutable reference to the requested side (entry or exit) of a
  // block's desired state.
  auto StateRef = [](BlockInfo &Info, bool Incoming) -> ZAState & {
    if (Incoming)
      return Info.DesiredIncomingState;
    return Info.DesiredOutgoingState;
  };

  // A block still needs a state if the side being propagated to does not yet
  // hold a legal edge bundle state.
  auto NeedsState = [&](BlockInfo &Info) {
    return !isLegalEdgeBundleZAState(StateRef(Info, Forwards));
  };

  // Seed the worklist with every block that has no legal state yet.
  SmallVector<MachineBasicBlock *> Pending;
  unsigned BlockNum = 0;
  for (BlockInfo &Info : FnInfo.Blocks) {
    if (NeedsState(Info))
      Pending.push_back(MF->getBlockNumbered(BlockNum));
    ++BlockNum;
  }

  while (!Pending.empty()) {
    MachineBasicBlock *MBB = Pending.pop_back_val();
    BlockInfo &Info = FnInfo.Blocks[MBB->getNumber()];

    // Tally the legal states wanted by the blocks we are propagating from
    // (predecessors when going forwards, successors when going backwards).
    int Votes[ZAState::NUM_ZA_STATE] = {0};
    for (MachineBasicBlock *Source :
         Forwards ? predecessors(MBB) : successors(MBB)) {
      BlockInfo &SourceInfo = FnInfo.Blocks[Source->getNumber()];
      ZAState Wanted = StateRef(SourceInfo, !Forwards);
      if (isLegalEdgeBundleZAState(Wanted))
        Votes[Wanted]++;
    }

    // The winner is the state with the highest tally. With no legal votes
    // every tally is zero, so this yields the enumerator with value 0.
    ZAState Winner = static_cast<ZAState>(max_element(Votes) - Votes);

    ZAState &Current = StateRef(Info, Forwards);
    if (Winner == Current)
      continue; // Nothing changed, so there is nothing further to propagate.
    Current = Winner;

    // A block with no preference on its other side passes the state straight
    // through.
    ZAState &Opposite = StateRef(Info, !Forwards);
    if (Opposite == ZAState::ANY)
      Opposite = Winner;

    // Revisit downstream blocks (in the propagation direction) that still
    // lack a legal state, as this update may now give them one.
    for (MachineBasicBlock *Next :
         Forwards ? successors(MBB) : predecessors(MBB)) {
      if (NeedsState(FnInfo.Blocks[Next->getNumber()]))
        Pending.push_back(Next);
    }
  }
}

/// Assigns each edge bundle a ZA state based on the needed states of blocks
/// that have incoming or outgoing edges in that bundle.
SmallVector<ZAState>
Expand All @@ -440,40 +513,36 @@ MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
// Attempt to assign a ZA state for this bundle that minimizes state
// transitions. Edges within loops are given a higher weight as we assume
// they will be executed more than once.
// TODO: We should propagate desired incoming/outgoing states through blocks
// that have the "ANY" state first to make better global decisions.
int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
for (unsigned BlockID : Bundles.getBlocks(I)) {
LLVM_DEBUG(dbgs() << "- bb." << BlockID);

const BlockInfo &Block = FnInfo.Blocks[BlockID];
if (Block.Insts.empty()) {
LLVM_DEBUG(dbgs() << " (no state preference)\n");
continue;
}
bool InEdge = Bundles.getBundle(BlockID, /*Out=*/false) == I;
bool OutEdge = Bundles.getBundle(BlockID, /*Out=*/true) == I;

ZAState DesiredIncomingState = Block.Insts.front().NeededState;
if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
EdgeStateCounts[DesiredIncomingState]++;
bool LegalInEdge =
InEdge && isLegalEdgeBundleZAState(Block.DesiredIncomingState);
bool LegalOutEgde =
OutEdge && isLegalEdgeBundleZAState(Block.DesiredOutgoingState);
if (LegalInEdge) {
LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
<< getZAStateString(DesiredIncomingState));
<< getZAStateString(Block.DesiredIncomingState));
EdgeStateCounts[Block.DesiredIncomingState]++;
}
ZAState DesiredOutgoingState = Block.Insts.back().NeededState;
if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) {
EdgeStateCounts[DesiredOutgoingState]++;
if (LegalOutEgde) {
LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
<< getZAStateString(DesiredOutgoingState));
<< getZAStateString(Block.DesiredOutgoingState));
EdgeStateCounts[Block.DesiredOutgoingState]++;
}
if (!LegalInEdge && !LegalOutEgde)
LLVM_DEBUG(dbgs() << " (no state preference)");
LLVM_DEBUG(dbgs() << '\n');
}

ZAState BundleState =
ZAState(max_element(EdgeStateCounts) - EdgeStateCounts);

// Force ZA to be active in bundles that don't have a preferred state.
// TODO: Something better here (to avoid extra mode switches).
if (BundleState == ZAState::ANY)
BundleState = ZAState::ACTIVE;

Expand Down Expand Up @@ -918,6 +987,43 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();

FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);

if (OptLevel != CodeGenOptLevel::None) {
// Propagate desired states forwards then backwards. We propagate forwards
// first as this propagates desired states from inner to outer loops.
// Backwards propagation is then used to fill in any gaps. Note: Doing both
// in one step can give poor results. For example:
//
// ┌─────┐
// ┌─┤ BB0 ◄───┐
// │ └─┬───┘ │
// │ ┌─▼───◄──┐│
// │ │ BB1 │ ││
// │ └─┬┬──┘ ││
// │ │└─────┘│
// │ ┌─▼───┐ │
// │ │ BB2 ├───┘
// │ └─┬───┘
// │ ┌─▼───┐
// └─► BB3 │
// └─────┘
//
// If:
// - "BB0" and "BB2" (outer loop) have no state preference
// - "BB1" (inner loop) desires the ACTIVE state on entry/exit
// - "BB3" desires the LOCAL_SAVED state on entry
//
// If we propagate forwards first, ACTIVE is propagated from BB1 to BB2,
// then from BB2 to BB0. Which results in the inner and outer loops having
// the "ACTIVE" state. This avoids any state changes in the loops.
//
// If we propagate backwards first, we _could_ propagate LOCAL_SAVED from
// BB3 to BB0, which would result in a transition from ACTIVE -> LOCAL_SAVED
// in the outer loop.
for (bool Forwards : {true, false})
propagateDesiredStates(FnInfo, Forwards);
}

SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);

EmitContext Context;
Expand All @@ -941,4 +1047,6 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
return true;
}

/// Creates a new MachineSMEABI pass, relying on the constructor's default
/// arguments.
/// NOTE(review): this zero-argument form looks like the pre-change side of
/// the diff hunk; confirm whether it is still meant to exist alongside the
/// overload that takes a CodeGenOptLevel.
FunctionPass *llvm::createMachineSMEABIPass() { return new MachineSMEABI(); }
/// Creates a new MachineSMEABI pass configured with \p OptLevel.
///
/// The pass stores \p OptLevel and only runs desired-state propagation when it
/// is not CodeGenOptLevel::None.
///
/// \returns a newly allocated pass instance.
FunctionPass *llvm::createMachineSMEABIPass(CodeGenOptLevel OptLevel) {
return new MachineSMEABI(OptLevel);
}
7 changes: 2 additions & 5 deletions llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,6 @@ define i64 @test_many_callee_arguments(
ret i64 %ret
}

; FIXME: The new lowering should avoid saves/restores in the probing loop.
define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_state_agnostic" "probe-stack"="inline-asm" "stack-probe-size"="65536"{
; CHECK-LABEL: agnostic_za_buffer_alloc_with_stack_probes:
; CHECK: // %bb.0:
Expand Down Expand Up @@ -389,16 +388,14 @@ define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_s
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_state_size
; CHECK-NEWLOWERING-NEXT: mov x8, sp
; CHECK-NEWLOWERING-NEXT: sub x19, x8, x0
; CHECK-NEWLOWERING-NEXT: .LBB7_1: // =>This Inner Loop Header: Depth=1
; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16, lsl #12 // =65536
; CHECK-NEWLOWERING-NEXT: mov x0, x19
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_save
; CHECK-NEWLOWERING-NEXT: .LBB7_1: // =>This Inner Loop Header: Depth=1
; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16, lsl #12 // =65536
; CHECK-NEWLOWERING-NEXT: cmp sp, x19
; CHECK-NEWLOWERING-NEXT: b.le .LBB7_3
; CHECK-NEWLOWERING-NEXT: // %bb.2: // in Loop: Header=BB7_1 Depth=1
; CHECK-NEWLOWERING-NEXT: mov x0, x19
; CHECK-NEWLOWERING-NEXT: str xzr, [sp]
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_restore
; CHECK-NEWLOWERING-NEXT: b .LBB7_1
; CHECK-NEWLOWERING-NEXT: .LBB7_3:
; CHECK-NEWLOWERING-NEXT: mov sp, x19
Expand Down
Loading