@@ -121,8 +121,10 @@ struct InstInfo {
121121// / Contains the needed ZA state for each instruction in a block. Instructions
122122// / that do not require a ZA state are not recorded.
123123struct BlockInfo {
124- ZAState FixedEntryState{ZAState::ANY};
125124 SmallVector<InstInfo> Insts;
125+ ZAState FixedEntryState{ZAState::ANY};
126+ ZAState DesiredIncomingState{ZAState::ANY};
127+ ZAState DesiredOutgoingState{ZAState::ANY};
126128 LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
127129 LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
128130};
@@ -175,10 +177,15 @@ class EmitContext {
175177 Register AgnosticZABufferPtr = AArch64::NoRegister;
176178};
177179
180+ // / Checks if \p State is a legal edge bundle state. For a state to be a legal
181+ // / bundle state, it must be possible to transition from it to any other bundle
182+ // / state without losing any ZA state. This is the case for ACTIVE/LOCAL_SAVED,
183+ // / as you can transition between those states by saving/restoring ZA. The OFF
184+ // / state would not be legal, as transitioning to it drops the content of ZA.
178185static bool isLegalEdgeBundleZAState (ZAState State) {
179186 switch (State) {
180- case ZAState::ACTIVE:
181- case ZAState::LOCAL_SAVED:
187+ case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
188+ case ZAState::LOCAL_SAVED: // ZA state is saved on the stack.
182189 return true ;
183190 default :
184191 return false ;
@@ -238,7 +245,8 @@ getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
238245struct MachineSMEABI : public MachineFunctionPass {
239246 inline static char ID = 0 ;
240247
241- MachineSMEABI () : MachineFunctionPass(ID) {}
248+ MachineSMEABI (CodeGenOptLevel OptLevel = CodeGenOptLevel::Default)
249+ : MachineFunctionPass(ID), OptLevel(OptLevel) {}
242250
243251 bool runOnMachineFunction (MachineFunction &MF) override ;
244252
@@ -267,6 +275,11 @@ struct MachineSMEABI : public MachineFunctionPass {
267275 const EdgeBundles &Bundles,
268276 ArrayRef<ZAState> BundleStates);
269277
278+ // / Propagates desired states forwards (from predecessors -> successors) if
279+ // / \p Forwards, otherwise, propagates backwards (from successors ->
280+ // / predecessors).
281+ void propagateDesiredStates (FunctionInfo &FnInfo, bool Forwards = true );
282+
270283 // Emission routines for private and shared ZA functions (using lazy saves).
271284 void emitNewZAPrologue (MachineBasicBlock &MBB,
272285 MachineBasicBlock::iterator MBBI);
@@ -335,12 +348,15 @@ struct MachineSMEABI : public MachineFunctionPass {
335348 MachineBasicBlock::iterator MBBI, DebugLoc DL);
336349
337350private:
351+ CodeGenOptLevel OptLevel = CodeGenOptLevel::Default;
352+
338353 MachineFunction *MF = nullptr ;
339354 const AArch64Subtarget *Subtarget = nullptr ;
340355 const AArch64RegisterInfo *TRI = nullptr ;
341356 const AArch64FunctionInfo *AFI = nullptr ;
342357 const TargetInstrInfo *TII = nullptr ;
343358 MachineRegisterInfo *MRI = nullptr ;
359+ MachineLoopInfo *MLI = nullptr ;
344360};
345361
346362static LiveRegs getPhysLiveRegs (LiveRegUnits const &LiveUnits) {
@@ -422,12 +438,69 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
422438
423439 // Reverse vector (as we had to iterate backwards for liveness).
424440 std::reverse (Block.Insts .begin (), Block.Insts .end ());
441+
442+ // Record the desired states on entry/exit of this block. These are the
443+ // states that would not incur a state transition.
444+ if (!Block.Insts .empty ()) {
445+ Block.DesiredIncomingState = Block.Insts .front ().NeededState ;
446+ Block.DesiredOutgoingState = Block.Insts .back ().NeededState ;
447+ }
425448 }
426449
427450 return FunctionInfo{std::move (Blocks), AfterSMEProloguePt,
428451 PhysLiveRegsAfterSMEPrologue};
429452}
430453
454+ void MachineSMEABI::propagateDesiredStates (FunctionInfo &FnInfo,
455+ bool Forwards) {
456+ // If `Forwards`, this propagates desired states from predecessors to
457+ // successors, otherwise, this propagates states from successors to
458+ // predecessors.
459+ auto GetBlockState = [](BlockInfo &Block, bool Incoming) -> ZAState & {
460+ return Incoming ? Block.DesiredIncomingState : Block.DesiredOutgoingState ;
461+ };
462+
463+ SmallVector<MachineBasicBlock *> Worklist;
464+ for (auto [BlockID, BlockInfo] : enumerate(FnInfo.Blocks )) {
465+ if (!isLegalEdgeBundleZAState (GetBlockState (BlockInfo, Forwards)))
466+ Worklist.push_back (MF->getBlockNumbered (BlockID));
467+ }
468+
469+ while (!Worklist.empty ()) {
470+ MachineBasicBlock *MBB = Worklist.pop_back_val ();
471+ BlockInfo &Block = FnInfo.Blocks [MBB->getNumber ()];
472+
473+ // Pick a legal edge bundle state that matches the majority of
474+ // predecessors/successors.
475+ int StateCounts[ZAState::NUM_ZA_STATE] = {0 };
476+ for (MachineBasicBlock *PredOrSucc :
477+ Forwards ? predecessors (MBB) : successors (MBB)) {
478+ BlockInfo &PredOrSuccBlock = FnInfo.Blocks [PredOrSucc->getNumber ()];
479+ ZAState ZAState = GetBlockState (PredOrSuccBlock, !Forwards);
480+ if (isLegalEdgeBundleZAState (ZAState))
481+ StateCounts[ZAState]++;
482+ }
483+
484+ ZAState PropagatedState = ZAState (max_element (StateCounts) - StateCounts);
485+ ZAState &CurrentState = GetBlockState (Block, Forwards);
486+ if (PropagatedState != CurrentState) {
487+ CurrentState = PropagatedState;
488+ ZAState &OtherState = GetBlockState (Block, !Forwards);
489+ // Propagate to the incoming/outgoing state if that is also "ANY".
490+ if (OtherState == ZAState::ANY)
491+ OtherState = PropagatedState;
492+ // Push any successors/predecessors that may need updating to the
493+ // worklist.
494+ for (MachineBasicBlock *SuccOrPred :
495+ Forwards ? successors (MBB) : predecessors (MBB)) {
496+ BlockInfo &SuccOrPredBlock = FnInfo.Blocks [SuccOrPred->getNumber ()];
497+ if (!isLegalEdgeBundleZAState (GetBlockState (SuccOrPredBlock, Forwards)))
498+ Worklist.push_back (SuccOrPred);
499+ }
500+ }
501+ }
502+ }
503+
431504// / Assigns each edge bundle a ZA state based on the needed states of blocks
432505// / that have incoming or outgoing edges in that bundle.
433506SmallVector<ZAState>
@@ -440,40 +513,36 @@ MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
440513 // Attempt to assign a ZA state for this bundle that minimizes state
441514 // transitions. Edges within loops are given a higher weight as we assume
442515 // they will be executed more than once.
443- // TODO: We should propagate desired incoming/outgoing states through blocks
444- // that have the "ANY" state first to make better global decisions.
445516 int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0 };
446517 for (unsigned BlockID : Bundles.getBlocks (I)) {
447518 LLVM_DEBUG (dbgs () << " - bb." << BlockID);
448519
449520 const BlockInfo &Block = FnInfo.Blocks [BlockID];
450- if (Block.Insts .empty ()) {
451- LLVM_DEBUG (dbgs () << " (no state preference)\n " );
452- continue ;
453- }
454521 bool InEdge = Bundles.getBundle (BlockID, /* Out=*/ false ) == I;
455522 bool OutEdge = Bundles.getBundle (BlockID, /* Out=*/ true ) == I;
456523
457- ZAState DesiredIncomingState = Block.Insts .front ().NeededState ;
458- if (InEdge && isLegalEdgeBundleZAState (DesiredIncomingState)) {
459- EdgeStateCounts[DesiredIncomingState]++;
524+ bool LegalInEdge =
525+ InEdge && isLegalEdgeBundleZAState (Block.DesiredIncomingState );
526+ bool LegalOutEgde =
527+ OutEdge && isLegalEdgeBundleZAState (Block.DesiredOutgoingState );
528+ if (LegalInEdge) {
460529 LLVM_DEBUG (dbgs () << " DesiredIncomingState: "
461- << getZAStateString (DesiredIncomingState));
530+ << getZAStateString (Block.DesiredIncomingState ));
531+ EdgeStateCounts[Block.DesiredIncomingState ]++;
462532 }
463- ZAState DesiredOutgoingState = Block.Insts .back ().NeededState ;
464- if (OutEdge && isLegalEdgeBundleZAState (DesiredOutgoingState)) {
465- EdgeStateCounts[DesiredOutgoingState]++;
533+ if (LegalOutEgde) {
466534 LLVM_DEBUG (dbgs () << " DesiredOutgoingState: "
467- << getZAStateString (DesiredOutgoingState));
535+ << getZAStateString (Block.DesiredOutgoingState ));
536+ EdgeStateCounts[Block.DesiredOutgoingState ]++;
468537 }
538+ if (!LegalInEdge && !LegalOutEgde)
539+ LLVM_DEBUG (dbgs () << " (no state preference)" );
469540 LLVM_DEBUG (dbgs () << ' \n ' );
470541 }
471542
472543 ZAState BundleState =
473544 ZAState (max_element (EdgeStateCounts) - EdgeStateCounts);
474545
475- // Force ZA to be active in bundles that don't have a preferred state.
476- // TODO: Something better here (to avoid extra mode switches).
477546 if (BundleState == ZAState::ANY)
478547 BundleState = ZAState::ACTIVE;
479548
@@ -918,6 +987,43 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
918987 getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles ();
919988
920989 FunctionInfo FnInfo = collectNeededZAStates (SMEFnAttrs);
990+
991+ if (OptLevel != CodeGenOptLevel::None) {
992+ // Propagate desired states forward, then backwards. Most of the propagation
993+ // should be done in the forward step, and backwards propagation is then
994+ // used to fill in the gaps. Note: Doing both in one step can give poor
995+ // results. For example, consider this subgraph:
996+ //
997+ // ┌─────┐
998+ // ┌─┤ BB0 ◄───┐
999+ // │ └─┬───┘ │
1000+ // │ ┌─▼───◄──┐│
1001+ // │ │ BB1 │ ││
1002+ // │ └─┬┬──┘ ││
1003+ // │ │└─────┘│
1004+ // │ ┌─▼───┐ │
1005+ // │ │ BB2 ├───┘
1006+ // │ └─┬───┘
1007+ // │ ┌─▼───┐
1008+ // └─► BB3 │
1009+ // └─────┘
1010+ //
1011+ // If:
1012+ // - "BB0" and "BB2" (outer loop) has no state preference
1013+ // - "BB1" (inner loop) desires the ACTIVE state on entry/exit
1014+ // - "BB3" desires the LOCAL_SAVED state on entry
1015+ //
1016+ // If we propagate forwards first, ACTIVE is propagated from BB1 to BB2,
1017+ // then from BB2 to BB0. Which results in the inner and outer loops having
1018+ // the "ACTIVE" state. This avoids any state changes in the loops.
1019+ //
1020+ // If we propagate backwards first, we _could_ propagate LOCAL_SAVED from
1021+ // BB3 to BB0, which would result in a transition from ACTIVE -> LOCAL_SAVED
1022+ // in the outer loop.
1023+ for (bool Forwards : {true , false })
1024+ propagateDesiredStates (FnInfo, Forwards);
1025+ }
1026+
9211027 SmallVector<ZAState> BundleStates = assignBundleZAStates (Bundles, FnInfo);
9221028
9231029 EmitContext Context;
@@ -941,4 +1047,6 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
9411047 return true ;
9421048}
9431049
944- FunctionPass *llvm::createMachineSMEABIPass () { return new MachineSMEABI (); }
1050+ FunctionPass *llvm::createMachineSMEABIPass (CodeGenOptLevel OptLevel) {
1051+ return new MachineSMEABI (OptLevel);
1052+ }
0 commit comments