Skip to content

Commit 70cf1b0

Browse files
MacDueaokblast
authored andcommitted
[AArch64][SME] Propagate desired ZA states in the MachineSMEABIPass (llvm#149510)
This patch adds a step to the MachineSMEABIPass that propagates desired ZA states. This aims to pick better ZA states for edge bundles, as when many (or all) blocks in a bundle do not have a preferred ZA state, the ZA state assigned to a bundle can be less than ideal. An important case is nested loops, where only the inner loop has a preferred ZA state. Here we'd like to propagate the ZA state from the inner loop to the outer loops (to avoid saves/restores in any loop).
1 parent 99d6ad8 commit 70cf1b0

File tree

8 files changed

+513
-183
lines changed

8 files changed

+513
-183
lines changed

llvm/lib/Target/AArch64/AArch64.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass();
6060
FunctionPass *createAArch64CollectLOHPass();
6161
FunctionPass *createSMEABIPass();
6262
FunctionPass *createSMEPeepholeOptPass();
63-
FunctionPass *createMachineSMEABIPass();
63+
FunctionPass *createMachineSMEABIPass(CodeGenOptLevel);
6464
ModulePass *createSVEIntrinsicOptsPass();
6565
InstructionSelector *
6666
createAArch64InstructionSelector(const AArch64TargetMachine &,

llvm/lib/Target/AArch64/AArch64TargetMachine.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -764,8 +764,8 @@ bool AArch64PassConfig::addGlobalInstructionSelect() {
764764
}
765765

766766
void AArch64PassConfig::addMachineSSAOptimization() {
767-
if (EnableNewSMEABILowering && TM->getOptLevel() != CodeGenOptLevel::None)
768-
addPass(createMachineSMEABIPass());
767+
if (TM->getOptLevel() != CodeGenOptLevel::None && EnableNewSMEABILowering)
768+
addPass(createMachineSMEABIPass(TM->getOptLevel()));
769769

770770
if (TM->getOptLevel() != CodeGenOptLevel::None && EnableSMEPeepholeOpt)
771771
addPass(createSMEPeepholeOptPass());
@@ -798,7 +798,7 @@ bool AArch64PassConfig::addILPOpts() {
798798

799799
void AArch64PassConfig::addPreRegAlloc() {
800800
if (TM->getOptLevel() == CodeGenOptLevel::None && EnableNewSMEABILowering)
801-
addPass(createMachineSMEABIPass());
801+
addPass(createMachineSMEABIPass(CodeGenOptLevel::None));
802802

803803
// Change dead register definitions to refer to the zero register.
804804
if (TM->getOptLevel() != CodeGenOptLevel::None &&

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 129 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,10 @@ struct InstInfo {
121121
/// Contains the needed ZA state for each instruction in a block. Instructions
122122
/// that do not require a ZA state are not recorded.
123123
struct BlockInfo {
124-
ZAState FixedEntryState{ZAState::ANY};
125124
SmallVector<InstInfo> Insts;
125+
ZAState FixedEntryState{ZAState::ANY};
126+
ZAState DesiredIncomingState{ZAState::ANY};
127+
ZAState DesiredOutgoingState{ZAState::ANY};
126128
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
127129
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
128130
};
@@ -175,10 +177,15 @@ class EmitContext {
175177
Register AgnosticZABufferPtr = AArch64::NoRegister;
176178
};
177179

180+
/// Checks if \p State is a legal edge bundle state. For a state to be a legal
181+
/// bundle state, it must be possible to transition from it to any other bundle
182+
/// state without losing any ZA state. This is the case for ACTIVE/LOCAL_SAVED,
183+
/// as you can transition between those states by saving/restoring ZA. The OFF
184+
/// state would not be legal, as transitioning to it drops the content of ZA.
178185
static bool isLegalEdgeBundleZAState(ZAState State) {
179186
switch (State) {
180-
case ZAState::ACTIVE:
181-
case ZAState::LOCAL_SAVED:
187+
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
188+
case ZAState::LOCAL_SAVED: // ZA state is saved on the stack.
182189
return true;
183190
default:
184191
return false;
@@ -238,7 +245,8 @@ getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
238245
struct MachineSMEABI : public MachineFunctionPass {
239246
inline static char ID = 0;
240247

241-
MachineSMEABI() : MachineFunctionPass(ID) {}
248+
MachineSMEABI(CodeGenOptLevel OptLevel = CodeGenOptLevel::Default)
249+
: MachineFunctionPass(ID), OptLevel(OptLevel) {}
242250

243251
bool runOnMachineFunction(MachineFunction &MF) override;
244252

@@ -267,6 +275,11 @@ struct MachineSMEABI : public MachineFunctionPass {
267275
const EdgeBundles &Bundles,
268276
ArrayRef<ZAState> BundleStates);
269277

278+
/// Propagates desired states forwards (from predecessors -> successors) if
279+
/// \p Forwards, otherwise, propagates backwards (from successors ->
280+
/// predecessors).
281+
void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);
282+
270283
// Emission routines for private and shared ZA functions (using lazy saves).
271284
void emitNewZAPrologue(MachineBasicBlock &MBB,
272285
MachineBasicBlock::iterator MBBI);
@@ -335,12 +348,15 @@ struct MachineSMEABI : public MachineFunctionPass {
335348
MachineBasicBlock::iterator MBBI, DebugLoc DL);
336349

337350
private:
351+
CodeGenOptLevel OptLevel = CodeGenOptLevel::Default;
352+
338353
MachineFunction *MF = nullptr;
339354
const AArch64Subtarget *Subtarget = nullptr;
340355
const AArch64RegisterInfo *TRI = nullptr;
341356
const AArch64FunctionInfo *AFI = nullptr;
342357
const TargetInstrInfo *TII = nullptr;
343358
MachineRegisterInfo *MRI = nullptr;
359+
MachineLoopInfo *MLI = nullptr;
344360
};
345361

346362
static LiveRegs getPhysLiveRegs(LiveRegUnits const &LiveUnits) {
@@ -422,12 +438,69 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
422438

423439
// Reverse vector (as we had to iterate backwards for liveness).
424440
std::reverse(Block.Insts.begin(), Block.Insts.end());
441+
442+
// Record the desired states on entry/exit of this block. These are the
443+
// states that would not incur a state transition.
444+
if (!Block.Insts.empty()) {
445+
Block.DesiredIncomingState = Block.Insts.front().NeededState;
446+
Block.DesiredOutgoingState = Block.Insts.back().NeededState;
447+
}
425448
}
426449

427450
return FunctionInfo{std::move(Blocks), AfterSMEProloguePt,
428451
PhysLiveRegsAfterSMEPrologue};
429452
}
430453

454+
void MachineSMEABI::propagateDesiredStates(FunctionInfo &FnInfo,
455+
bool Forwards) {
456+
// If `Forwards`, this propagates desired states from predecessors to
457+
// successors, otherwise, this propagates states from successors to
458+
// predecessors.
459+
auto GetBlockState = [](BlockInfo &Block, bool Incoming) -> ZAState & {
460+
return Incoming ? Block.DesiredIncomingState : Block.DesiredOutgoingState;
461+
};
462+
463+
SmallVector<MachineBasicBlock *> Worklist;
464+
for (auto [BlockID, BlockInfo] : enumerate(FnInfo.Blocks)) {
465+
if (!isLegalEdgeBundleZAState(GetBlockState(BlockInfo, Forwards)))
466+
Worklist.push_back(MF->getBlockNumbered(BlockID));
467+
}
468+
469+
while (!Worklist.empty()) {
470+
MachineBasicBlock *MBB = Worklist.pop_back_val();
471+
BlockInfo &Block = FnInfo.Blocks[MBB->getNumber()];
472+
473+
// Pick a legal edge bundle state that matches the majority of
474+
// predecessors/successors.
475+
int StateCounts[ZAState::NUM_ZA_STATE] = {0};
476+
for (MachineBasicBlock *PredOrSucc :
477+
Forwards ? predecessors(MBB) : successors(MBB)) {
478+
BlockInfo &PredOrSuccBlock = FnInfo.Blocks[PredOrSucc->getNumber()];
479+
ZAState ZAState = GetBlockState(PredOrSuccBlock, !Forwards);
480+
if (isLegalEdgeBundleZAState(ZAState))
481+
StateCounts[ZAState]++;
482+
}
483+
484+
ZAState PropagatedState = ZAState(max_element(StateCounts) - StateCounts);
485+
ZAState &CurrentState = GetBlockState(Block, Forwards);
486+
if (PropagatedState != CurrentState) {
487+
CurrentState = PropagatedState;
488+
ZAState &OtherState = GetBlockState(Block, !Forwards);
489+
// Propagate to the incoming/outgoing state if that is also "ANY".
490+
if (OtherState == ZAState::ANY)
491+
OtherState = PropagatedState;
492+
// Push any successors/predecessors that may need updating to the
493+
// worklist.
494+
for (MachineBasicBlock *SuccOrPred :
495+
Forwards ? successors(MBB) : predecessors(MBB)) {
496+
BlockInfo &SuccOrPredBlock = FnInfo.Blocks[SuccOrPred->getNumber()];
497+
if (!isLegalEdgeBundleZAState(GetBlockState(SuccOrPredBlock, Forwards)))
498+
Worklist.push_back(SuccOrPred);
499+
}
500+
}
501+
}
502+
}
503+
431504
/// Assigns each edge bundle a ZA state based on the needed states of blocks
432505
/// that have incoming or outgoing edges in that bundle.
433506
SmallVector<ZAState>
@@ -440,40 +513,36 @@ MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
440513
// Attempt to assign a ZA state for this bundle that minimizes state
441514
// transitions. Edges within loops are given a higher weight as we assume
442515
// they will be executed more than once.
443-
// TODO: We should propagate desired incoming/outgoing states through blocks
444-
// that have the "ANY" state first to make better global decisions.
445516
int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
446517
for (unsigned BlockID : Bundles.getBlocks(I)) {
447518
LLVM_DEBUG(dbgs() << "- bb." << BlockID);
448519

449520
const BlockInfo &Block = FnInfo.Blocks[BlockID];
450-
if (Block.Insts.empty()) {
451-
LLVM_DEBUG(dbgs() << " (no state preference)\n");
452-
continue;
453-
}
454521
bool InEdge = Bundles.getBundle(BlockID, /*Out=*/false) == I;
455522
bool OutEdge = Bundles.getBundle(BlockID, /*Out=*/true) == I;
456523

457-
ZAState DesiredIncomingState = Block.Insts.front().NeededState;
458-
if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
459-
EdgeStateCounts[DesiredIncomingState]++;
524+
bool LegalInEdge =
525+
InEdge && isLegalEdgeBundleZAState(Block.DesiredIncomingState);
526+
bool LegalOutEgde =
527+
OutEdge && isLegalEdgeBundleZAState(Block.DesiredOutgoingState);
528+
if (LegalInEdge) {
460529
LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
461-
<< getZAStateString(DesiredIncomingState));
530+
<< getZAStateString(Block.DesiredIncomingState));
531+
EdgeStateCounts[Block.DesiredIncomingState]++;
462532
}
463-
ZAState DesiredOutgoingState = Block.Insts.back().NeededState;
464-
if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) {
465-
EdgeStateCounts[DesiredOutgoingState]++;
533+
if (LegalOutEgde) {
466534
LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
467-
<< getZAStateString(DesiredOutgoingState));
535+
<< getZAStateString(Block.DesiredOutgoingState));
536+
EdgeStateCounts[Block.DesiredOutgoingState]++;
468537
}
538+
if (!LegalInEdge && !LegalOutEgde)
539+
LLVM_DEBUG(dbgs() << " (no state preference)");
469540
LLVM_DEBUG(dbgs() << '\n');
470541
}
471542

472543
ZAState BundleState =
473544
ZAState(max_element(EdgeStateCounts) - EdgeStateCounts);
474545

475-
// Force ZA to be active in bundles that don't have a preferred state.
476-
// TODO: Something better here (to avoid extra mode switches).
477546
if (BundleState == ZAState::ANY)
478547
BundleState = ZAState::ACTIVE;
479548

@@ -918,6 +987,43 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
918987
getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
919988

920989
FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);
990+
991+
if (OptLevel != CodeGenOptLevel::None) {
992+
// Propagate desired states forward, then backwards. Most of the propagation
993+
// should be done in the forward step, and backwards propagation is then
994+
// used to fill in the gaps. Note: Doing both in one step can give poor
995+
// results. For example, consider this subgraph:
996+
//
997+
// ┌─────┐
998+
// ┌─┤ BB0 ◄───┐
999+
// │ └─┬───┘ │
1000+
// │ ┌─▼───◄──┐│
1001+
// │ │ BB1 │ ││
1002+
// │ └─┬┬──┘ ││
1003+
// │ │└─────┘│
1004+
// │ ┌─▼───┐ │
1005+
// │ │ BB2 ├───┘
1006+
// │ └─┬───┘
1007+
// │ ┌─▼───┐
1008+
// └─► BB3 │
1009+
// └─────┘
1010+
//
1011+
// If:
1012+
// - "BB0" and "BB2" (outer loop) has no state preference
1013+
// - "BB1" (inner loop) desires the ACTIVE state on entry/exit
1014+
// - "BB3" desires the LOCAL_SAVED state on entry
1015+
//
1016+
// If we propagate forwards first, ACTIVE is propagated from BB1 to BB2,
1017+
// then from BB2 to BB0. Which results in the inner and outer loops having
1018+
// the "ACTIVE" state. This avoids any state changes in the loops.
1019+
//
1020+
// If we propagate backwards first, we _could_ propagate LOCAL_SAVED from
1021+
// BB3 to BB0, which would result in a transition from ACTIVE -> LOCAL_SAVED
1022+
// in the outer loop.
1023+
for (bool Forwards : {true, false})
1024+
propagateDesiredStates(FnInfo, Forwards);
1025+
}
1026+
9211027
SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);
9221028

9231029
EmitContext Context;
@@ -941,4 +1047,6 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
9411047
return true;
9421048
}
9431049

944-
FunctionPass *llvm::createMachineSMEABIPass() { return new MachineSMEABI(); }
1050+
FunctionPass *llvm::createMachineSMEABIPass(CodeGenOptLevel OptLevel) {
1051+
return new MachineSMEABI(OptLevel);
1052+
}

llvm/test/CodeGen/AArch64/sme-agnostic-za.ll

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,6 @@ define i64 @test_many_callee_arguments(
351351
ret i64 %ret
352352
}
353353

354-
; FIXME: The new lowering should avoid saves/restores in the probing loop.
355354
define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_state_agnostic" "probe-stack"="inline-asm" "stack-probe-size"="65536"{
356355
; CHECK-LABEL: agnostic_za_buffer_alloc_with_stack_probes:
357356
; CHECK: // %bb.0:
@@ -389,16 +388,14 @@ define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_s
389388
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_state_size
390389
; CHECK-NEWLOWERING-NEXT: mov x8, sp
391390
; CHECK-NEWLOWERING-NEXT: sub x19, x8, x0
392-
; CHECK-NEWLOWERING-NEXT: .LBB7_1: // =>This Inner Loop Header: Depth=1
393-
; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16, lsl #12 // =65536
394391
; CHECK-NEWLOWERING-NEXT: mov x0, x19
395392
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_save
393+
; CHECK-NEWLOWERING-NEXT: .LBB7_1: // =>This Inner Loop Header: Depth=1
394+
; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16, lsl #12 // =65536
396395
; CHECK-NEWLOWERING-NEXT: cmp sp, x19
397396
; CHECK-NEWLOWERING-NEXT: b.le .LBB7_3
398397
; CHECK-NEWLOWERING-NEXT: // %bb.2: // in Loop: Header=BB7_1 Depth=1
399-
; CHECK-NEWLOWERING-NEXT: mov x0, x19
400398
; CHECK-NEWLOWERING-NEXT: str xzr, [sp]
401-
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_restore
402399
; CHECK-NEWLOWERING-NEXT: b .LBB7_1
403400
; CHECK-NEWLOWERING-NEXT: .LBB7_3:
404401
; CHECK-NEWLOWERING-NEXT: mov sp, x19

0 commit comments

Comments
 (0)