@@ -213,6 +213,11 @@ struct MachineSMEABI : public MachineFunctionPass {
213213 // / E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
214214 void insertStateChanges ();
215215
216+ // / Propagates desired states forwards (from predecessors -> successors) if
217+ // / \p Forwards, otherwise, propagates backwards (from successors ->
218+ // / predecessors).
219+ void propagateDesiredStates (bool Forwards = true );
220+
216221 // Emission routines for private and shared ZA functions (using lazy saves).
217222 void emitNewZAPrologue (MachineBasicBlock &MBB,
218223 MachineBasicBlock::iterator MBBI);
@@ -287,8 +292,10 @@ struct MachineSMEABI : public MachineFunctionPass {
287292 // / Contains the needed ZA state for each instruction in a block.
288293 // / Instructions that do not require a ZA state are not recorded.
289294 struct BlockInfo {
290- ZAState FixedEntryState{ZAState::ANY};
291295 SmallVector<InstInfo> Insts;
296+ ZAState FixedEntryState{ZAState::ANY};
297+ ZAState DesiredIncomingState{ZAState::ANY};
298+ ZAState DesiredOutgoingState{ZAState::ANY};
292299 LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
293300 LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
294301 };
@@ -381,28 +388,80 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
381388
382389 // Reverse vector (as we had to iterate backwards for liveness).
383390 std::reverse (Block.Insts .begin (), Block.Insts .end ());
391+
392+ // Record the desired states on entry/exit of this block. These are the
393+ // states that would not incur a state transition.
394+ if (!Block.Insts .empty ()) {
395+ Block.DesiredIncomingState = Block.Insts .front ().NeededState ;
396+ Block.DesiredOutgoingState = Block.Insts .back ().NeededState ;
397+ }
398+ }
399+ }
400+
401+ void MachineSMEABI::propagateDesiredStates (bool Forwards) {
402+ // If `Forwards`, this propagates desired states from predecessors to
403+ // successors, otherwise, this propagates states from successors to
404+ // predecessors.
405+ auto GetBlockState = [](BlockInfo &Block, bool Incoming) -> ZAState & {
406+ return Incoming ? Block.DesiredIncomingState : Block.DesiredOutgoingState ;
407+ };
408+
409+ SmallVector<MachineBasicBlock *> Worklist;
410+ for (auto [BlockID, BlockInfo] : enumerate(State.Blocks )) {
411+ if (!isLegalEdgeBundleZAState (GetBlockState (BlockInfo, Forwards)))
412+ Worklist.push_back (MF->getBlockNumbered (BlockID));
413+ }
414+
415+ while (!Worklist.empty ()) {
416+ MachineBasicBlock *MBB = Worklist.pop_back_val ();
417+ auto &BlockInfo = State.Blocks [MBB->getNumber ()];
418+
419+ // Pick a legal edge bundle state that matches the majority of
420+ // predecessors/successors.
421+ int StateCounts[ZAState::NUM_ZA_STATE] = {0 };
422+ for (MachineBasicBlock *PredOrSucc :
423+ Forwards ? predecessors (MBB) : successors (MBB)) {
424+ auto &PredOrSuccBlockInfo = State.Blocks [PredOrSucc->getNumber ()];
425+ auto ZAState = GetBlockState (PredOrSuccBlockInfo, !Forwards);
426+ if (isLegalEdgeBundleZAState (ZAState))
427+ StateCounts[ZAState]++;
428+ }
429+
430+ ZAState PropagatedState = ZAState (max_element (StateCounts) - StateCounts);
431+ auto &CurrentState = GetBlockState (BlockInfo, Forwards);
432+ if (PropagatedState != CurrentState) {
433+ CurrentState = PropagatedState;
434+ auto &OtherState = GetBlockState (BlockInfo, !Forwards);
435+ // Propagate to the incoming/outgoing state if that is also "ANY".
436+ if (OtherState == ZAState::ANY)
437+ OtherState = PropagatedState;
438+ // Push any successors/predecessors that may need updating to the
439+ // worklist.
440+ for (MachineBasicBlock *SuccOrPred :
441+ Forwards ? successors (MBB) : predecessors (MBB)) {
442+ auto &SuccOrPredBlockInfo = State.Blocks [SuccOrPred->getNumber ()];
443+ if (!isLegalEdgeBundleZAState (
444+ GetBlockState (SuccOrPredBlockInfo, Forwards)))
445+ Worklist.push_back (SuccOrPred);
446+ }
447+ }
384448 }
385449}
386450
387451void MachineSMEABI::assignBundleZAStates () {
388452 State.BundleStates .resize (Bundles->getNumBundles ());
453+
389454 for (unsigned I = 0 , E = Bundles->getNumBundles (); I != E; ++I) {
390455 LLVM_DEBUG (dbgs () << " Assigning ZA state for edge bundle: " << I << ' \n ' );
391456
392457 // Attempt to assign a ZA state for this bundle that minimizes state
393458 // transitions. Edges within loops are given a higher weight as we assume
394459 // they will be executed more than once.
395- // TODO: We should propagate desired incoming/outgoing states through blocks
396- // that have the "ANY" state first to make better global decisions.
397460 int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0 };
398461 for (unsigned BlockID : Bundles->getBlocks (I)) {
399462 LLVM_DEBUG (dbgs () << " - bb." << BlockID);
400463
401- const BlockInfo &Block = State.Blocks [BlockID];
402- if (Block.Insts .empty ()) {
403- LLVM_DEBUG (dbgs () << " (no state preference)\n " );
404- continue ;
405- }
464+ BlockInfo &Block = State.Blocks [BlockID];
406465 bool IsLoop = MLI && MLI->getLoopFor (MF->getBlockNumbered (BlockID));
407466 bool InEdge = Bundles->getBundle (BlockID, /* Out=*/ false ) == I;
408467 bool OutEdge = Bundles->getBundle (BlockID, /* Out=*/ true ) == I;
@@ -411,26 +470,28 @@ void MachineSMEABI::assignBundleZAStates() {
411470 LLVM_DEBUG (dbgs () << " IsLoop" );
412471
413472 LLVM_DEBUG (dbgs () << " (EdgeWeight: " << EdgeWeight << ' )' );
414- ZAState DesiredIncomingState = Block.Insts .front ().NeededState ;
415- if (InEdge && isLegalEdgeBundleZAState (DesiredIncomingState)) {
416- EdgeStateCounts[DesiredIncomingState] += EdgeWeight;
473+ bool LegalInEdge =
474+ InEdge && isLegalEdgeBundleZAState (Block.DesiredIncomingState );
475+ bool LegalOutEgde =
476+ OutEdge && isLegalEdgeBundleZAState (Block.DesiredOutgoingState );
477+ if (LegalInEdge) {
417478 LLVM_DEBUG (dbgs () << " DesiredIncomingState: "
418- << getZAStateString (DesiredIncomingState));
479+ << getZAStateString (Block.DesiredIncomingState ));
480+ EdgeStateCounts[Block.DesiredIncomingState ] += EdgeWeight;
419481 }
420- ZAState DesiredOutgoingState = Block.Insts .back ().NeededState ;
421- if (OutEdge && isLegalEdgeBundleZAState (DesiredOutgoingState)) {
422- EdgeStateCounts[DesiredOutgoingState] += EdgeWeight;
482+ if (LegalOutEgde) {
423483 LLVM_DEBUG (dbgs () << " DesiredOutgoingState: "
424- << getZAStateString (DesiredOutgoingState));
484+ << getZAStateString (Block.DesiredOutgoingState ));
485+ EdgeStateCounts[Block.DesiredOutgoingState ] += EdgeWeight;
425486 }
487+ if (!LegalInEdge && !LegalOutEgde)
488+ LLVM_DEBUG (dbgs () << " (no state preference)" );
426489 LLVM_DEBUG (dbgs () << ' \n ' );
427490 }
428491
429492 ZAState BundleState =
430493 ZAState (max_element (EdgeStateCounts) - EdgeStateCounts);
431494
432- // Force ZA to be active in bundles that don't have a preferred state.
433- // TODO: Something better here (to avoid extra mode switches).
434495 if (BundleState == ZAState::ANY)
435496 BundleState = ZAState::ACTIVE;
436497
@@ -839,6 +900,10 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
839900 MLI = &getAnalysis<MachineLoopInfoWrapperPass>().getLI ();
840901
841902 collectNeededZAStates (SMEFnAttrs);
903+ if (OptLevel != CodeGenOptLevel::None) {
904+ for (bool Forwards : {true , false })
905+ propagateDesiredStates (Forwards);
906+ }
842907 assignBundleZAStates ();
843908 insertStateChanges ();
844909
0 commit comments