Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass();
FunctionPass *createAArch64CollectLOHPass();
FunctionPass *createSMEABIPass();
FunctionPass *createSMEPeepholeOptPass();
FunctionPass *createMachineSMEABIPass();
FunctionPass *createMachineSMEABIPass(CodeGenOptLevel);
ModulePass *createSVEIntrinsicOptsPass();
InstructionSelector *
createAArch64InstructionSelector(const AArch64TargetMachine &,
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,8 +764,8 @@ bool AArch64PassConfig::addGlobalInstructionSelect() {
}

void AArch64PassConfig::addMachineSSAOptimization() {
if (EnableNewSMEABILowering && TM->getOptLevel() != CodeGenOptLevel::None)
addPass(createMachineSMEABIPass());
if (TM->getOptLevel() != CodeGenOptLevel::None && EnableNewSMEABILowering)
addPass(createMachineSMEABIPass(TM->getOptLevel()));

if (TM->getOptLevel() != CodeGenOptLevel::None && EnableSMEPeepholeOpt)
addPass(createSMEPeepholeOptPass());
Expand Down Expand Up @@ -798,7 +798,7 @@ bool AArch64PassConfig::addILPOpts() {

void AArch64PassConfig::addPreRegAlloc() {
if (TM->getOptLevel() == CodeGenOptLevel::None && EnableNewSMEABILowering)
addPass(createMachineSMEABIPass());
addPass(createMachineSMEABIPass(CodeGenOptLevel::None));

// Change dead register definitions to refer to the zero register.
if (TM->getOptLevel() != CodeGenOptLevel::None &&
Expand Down
150 changes: 129 additions & 21 deletions llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ struct InstInfo {
/// Contains the needed ZA state for each instruction in a block. Instructions
/// that do not require a ZA state are not recorded.
struct BlockInfo {
// NOTE(review): FixedEntryState is declared both here and again after Insts
// in this listing. This looks like the removed and added sides of the diff
// shown together -- confirm only one declaration exists in the real file.
ZAState FixedEntryState{ZAState::ANY};
/// Instructions in this block that need a particular ZA state. The list is
/// built while walking the block backwards and then reversed, so it ends up
/// in forward order.
SmallVector<InstInfo> Insts;
/// Defaults to ANY. NOTE(review): presumably a state the block must be
/// entered in regardless of edge bundle assignment -- its uses are not
/// visible here, confirm.
ZAState FixedEntryState{ZAState::ANY};
/// ZA states on entry to/exit from this block that would not incur a state
/// transition. Initialised from the needed state of the first/last entry in
/// Insts (left as ANY when Insts is empty), and may later be overwritten by
/// MachineSMEABI::propagateDesiredStates().
ZAState DesiredIncomingState{ZAState::ANY};
ZAState DesiredOutgoingState{ZAState::ANY};
/// NOTE(review): presumably the physical registers live at block entry/exit
/// as summarised by getPhysLiveRegs() -- the computation is not visible
/// here, confirm.
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
};
Expand Down Expand Up @@ -175,10 +177,15 @@ class EmitContext {
Register AgnosticZABufferPtr = AArch64::NoRegister;
};

/// Checks if \p State is a legal edge bundle state. For a state to be a legal
/// bundle state, it must be possible to transition from it to any other bundle
/// state without losing any ZA state. This is the case for ACTIVE/LOCAL_SAVED,
/// as you can transition between those states by saving/restoring ZA. The OFF
/// state would not be legal, as transitioning to it drops the content of ZA.
static bool isLegalEdgeBundleZAState(ZAState State) {
switch (State) {
case ZAState::ACTIVE:
case ZAState::LOCAL_SAVED:
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
case ZAState::LOCAL_SAVED: // ZA state is saved on the stack.
return true;
default:
return false;
Expand Down Expand Up @@ -238,7 +245,8 @@ getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
struct MachineSMEABI : public MachineFunctionPass {
inline static char ID = 0;

MachineSMEABI() : MachineFunctionPass(ID) {}
MachineSMEABI(CodeGenOptLevel OptLevel = CodeGenOptLevel::Default)
: MachineFunctionPass(ID), OptLevel(OptLevel) {}

bool runOnMachineFunction(MachineFunction &MF) override;

Expand Down Expand Up @@ -267,6 +275,11 @@ struct MachineSMEABI : public MachineFunctionPass {
const EdgeBundles &Bundles,
ArrayRef<ZAState> BundleStates);

/// Propagates desired states forwards (from predecessors -> successors) if
/// \p Forwards, otherwise, propagates backwards (from successors ->
/// predecessors).
void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);

// Emission routines for private and shared ZA functions (using lazy saves).
void emitNewZAPrologue(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
Expand Down Expand Up @@ -335,12 +348,15 @@ struct MachineSMEABI : public MachineFunctionPass {
MachineBasicBlock::iterator MBBI, DebugLoc DL);

private:
CodeGenOptLevel OptLevel = CodeGenOptLevel::Default;

MachineFunction *MF = nullptr;
const AArch64Subtarget *Subtarget = nullptr;
const AArch64RegisterInfo *TRI = nullptr;
const AArch64FunctionInfo *AFI = nullptr;
const TargetInstrInfo *TII = nullptr;
MachineRegisterInfo *MRI = nullptr;
MachineLoopInfo *MLI = nullptr;
};

static LiveRegs getPhysLiveRegs(LiveRegUnits const &LiveUnits) {
Expand Down Expand Up @@ -422,12 +438,69 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {

// Reverse vector (as we had to iterate backwards for liveness).
std::reverse(Block.Insts.begin(), Block.Insts.end());

// Record the desired states on entry/exit of this block. These are the
// states that would not incur a state transition.
if (!Block.Insts.empty()) {
Block.DesiredIncomingState = Block.Insts.front().NeededState;
Block.DesiredOutgoingState = Block.Insts.back().NeededState;
}
}

return FunctionInfo{std::move(Blocks), AfterSMEProloguePt,
PhysLiveRegsAfterSMEPrologue};
}

/// Fills in desired ZA states for blocks that do not yet have a legal edge
/// bundle state on the side being propagated into.
///
/// When \p Forwards is true, a block's desired *incoming* state is derived
/// from the desired *outgoing* states of its predecessors. When false, a
/// block's desired *outgoing* state is derived from the desired *incoming*
/// states of its successors.
///
/// \param FnInfo   Per-block info; the DesiredIncomingState and
///                 DesiredOutgoingState fields of its blocks are updated in
///                 place.
/// \param Forwards Propagation direction (see above).
void MachineSMEABI::propagateDesiredStates(FunctionInfo &FnInfo,
                                           bool Forwards) {
  // Selects the incoming or outgoing desired state of a block, by reference
  // so that it can be both read and updated.
  auto DesiredState = [](BlockInfo &Info, bool Incoming) -> ZAState & {
    if (Incoming)
      return Info.DesiredIncomingState;
    return Info.DesiredOutgoingState;
  };

  // A block still needs a state if the side we are propagating into does not
  // hold a legal edge bundle state.
  auto NeedsState = [&](BlockInfo &Info) {
    return !isLegalEdgeBundleZAState(DesiredState(Info, Forwards));
  };

  // Seed the worklist with every block that lacks a legal state.
  SmallVector<MachineBasicBlock *> Pending;
  for (auto [Number, Info] : enumerate(FnInfo.Blocks))
    if (NeedsState(Info))
      Pending.push_back(MF->getBlockNumbered(Number));

  while (!Pending.empty()) {
    MachineBasicBlock *Current = Pending.pop_back_val();
    BlockInfo &CurrentInfo = FnInfo.Blocks[Current->getNumber()];

    // Tally the legal states wanted by the blocks we propagate from. Only
    // legal edge bundle states get a vote.
    int Votes[ZAState::NUM_ZA_STATE] = {0};
    for (MachineBasicBlock *Source :
         Forwards ? predecessors(Current) : successors(Current)) {
      ZAState Candidate =
          DesiredState(FnInfo.Blocks[Source->getNumber()], !Forwards);
      if (isLegalEdgeBundleZAState(Candidate))
        ++Votes[Candidate];
    }

    // The most popular state wins; ties go to the lowest-numbered state.
    // With no votes at all this yields the state numbered 0 (presumably
    // ZAState::ANY -- the enum definition is not visible here).
    ZAState Winner = static_cast<ZAState>(max_element(Votes) - Votes);

    ZAState &NearSide = DesiredState(CurrentInfo, Forwards);
    if (Winner == NearSide)
      continue; // Nothing changed, so nothing further to propagate.
    NearSide = Winner;

    // A block with no preference on its other side simply passes the state
    // through.
    ZAState &FarSide = DesiredState(CurrentInfo, !Forwards);
    if (FarSide == ZAState::ANY)
      FarSide = Winner;

    // Revisit downstream blocks (in propagation order) that still lack a
    // legal state, as this update may now give them one.
    for (MachineBasicBlock *Sink :
         Forwards ? successors(Current) : predecessors(Current)) {
      if (NeedsState(FnInfo.Blocks[Sink->getNumber()]))
        Pending.push_back(Sink);
    }
  }
}

/// Assigns each edge bundle a ZA state based on the needed states of blocks
/// that have incoming or outgoing edges in that bundle.
SmallVector<ZAState>
Expand All @@ -440,40 +513,36 @@ MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
// Attempt to assign a ZA state for this bundle that minimizes state
// transitions. Edges within loops are given a higher weight as we assume
// they will be executed more than once.
// TODO: We should propagate desired incoming/outgoing states through blocks
// that have the "ANY" state first to make better global decisions.
int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
for (unsigned BlockID : Bundles.getBlocks(I)) {
LLVM_DEBUG(dbgs() << "- bb." << BlockID);

const BlockInfo &Block = FnInfo.Blocks[BlockID];
if (Block.Insts.empty()) {
LLVM_DEBUG(dbgs() << " (no state preference)\n");
continue;
}
bool InEdge = Bundles.getBundle(BlockID, /*Out=*/false) == I;
bool OutEdge = Bundles.getBundle(BlockID, /*Out=*/true) == I;

ZAState DesiredIncomingState = Block.Insts.front().NeededState;
if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
EdgeStateCounts[DesiredIncomingState]++;
bool LegalInEdge =
InEdge && isLegalEdgeBundleZAState(Block.DesiredIncomingState);
bool LegalOutEgde =
OutEdge && isLegalEdgeBundleZAState(Block.DesiredOutgoingState);
if (LegalInEdge) {
LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
<< getZAStateString(DesiredIncomingState));
<< getZAStateString(Block.DesiredIncomingState));
EdgeStateCounts[Block.DesiredIncomingState]++;
}
ZAState DesiredOutgoingState = Block.Insts.back().NeededState;
if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) {
EdgeStateCounts[DesiredOutgoingState]++;
if (LegalOutEgde) {
LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
<< getZAStateString(DesiredOutgoingState));
<< getZAStateString(Block.DesiredOutgoingState));
EdgeStateCounts[Block.DesiredOutgoingState]++;
}
if (!LegalInEdge && !LegalOutEgde)
LLVM_DEBUG(dbgs() << " (no state preference)");
LLVM_DEBUG(dbgs() << '\n');
}

ZAState BundleState =
ZAState(max_element(EdgeStateCounts) - EdgeStateCounts);

// Force ZA to be active in bundles that don't have a preferred state.
// TODO: Something better here (to avoid extra mode switches).
if (BundleState == ZAState::ANY)
BundleState = ZAState::ACTIVE;

Expand Down Expand Up @@ -918,6 +987,43 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();

FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);

if (OptLevel != CodeGenOptLevel::None) {
// Propagate desired states forward, then backwards. Most of the propagation
// should be done in the forward step, and backwards propagation is then
// used to fill in the gaps. Note: Doing both in one step can give poor
// results. For example, consider this subgraph:
//
// ┌─────┐
// ┌─┤ BB0 ◄───┐
// │ └─┬───┘ │
// │ ┌─▼───◄──┐│
// │ │ BB1 │ ││
// │ └─┬┬──┘ ││
// │ │└─────┘│
// │ ┌─▼───┐ │
// │ │ BB2 ├───┘
// │ └─┬───┘
// │ ┌─▼───┐
// └─► BB3 │
// └─────┘
//
// If:
// - "BB0" and "BB2" (outer loop) have no state preference
// - "BB1" (inner loop) desires the ACTIVE state on entry/exit
// - "BB3" desires the LOCAL_SAVED state on entry
//
// If we propagate forwards first, ACTIVE is propagated from BB1 to BB2,
// then from BB2 to BB0, which results in the inner and outer loops having
// the "ACTIVE" state. This avoids any state changes in the loops.
//
// If we propagate backwards first, we _could_ propagate LOCAL_SAVED from
// BB3 to BB0, which would result in a transition from ACTIVE -> LOCAL_SAVED
// in the outer loop.
for (bool Forwards : {true, false})
propagateDesiredStates(FnInfo, Forwards);
}

SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);

EmitContext Context;
Expand All @@ -941,4 +1047,6 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
return true;
}

FunctionPass *llvm::createMachineSMEABIPass() { return new MachineSMEABI(); }
/// Creates a new MachineSMEABI pass instance, forwarding \p OptLevel to the
/// pass constructor. The pass skips desired-state propagation when
/// \p OptLevel is CodeGenOptLevel::None.
FunctionPass *llvm::createMachineSMEABIPass(CodeGenOptLevel OptLevel) {
  FunctionPass *Pass = new MachineSMEABI(OptLevel);
  return Pass;
}
7 changes: 2 additions & 5 deletions llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,6 @@ define i64 @test_many_callee_arguments(
ret i64 %ret
}

; FIXME: The new lowering should avoid saves/restores in the probing loop.
define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_state_agnostic" "probe-stack"="inline-asm" "stack-probe-size"="65536"{
; CHECK-LABEL: agnostic_za_buffer_alloc_with_stack_probes:
; CHECK: // %bb.0:
Expand Down Expand Up @@ -389,16 +388,14 @@ define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_s
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_state_size
; CHECK-NEWLOWERING-NEXT: mov x8, sp
; CHECK-NEWLOWERING-NEXT: sub x19, x8, x0
; CHECK-NEWLOWERING-NEXT: .LBB7_1: // =>This Inner Loop Header: Depth=1
; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16, lsl #12 // =65536
; CHECK-NEWLOWERING-NEXT: mov x0, x19
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_save
; CHECK-NEWLOWERING-NEXT: .LBB7_1: // =>This Inner Loop Header: Depth=1
; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16, lsl #12 // =65536
; CHECK-NEWLOWERING-NEXT: cmp sp, x19
; CHECK-NEWLOWERING-NEXT: b.le .LBB7_3
; CHECK-NEWLOWERING-NEXT: // %bb.2: // in Loop: Header=BB7_1 Depth=1
; CHECK-NEWLOWERING-NEXT: mov x0, x19
; CHECK-NEWLOWERING-NEXT: str xzr, [sp]
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_restore
; CHECK-NEWLOWERING-NEXT: b .LBB7_1
; CHECK-NEWLOWERING-NEXT: .LBB7_3:
; CHECK-NEWLOWERING-NEXT: mov sp, x19
Expand Down
Loading