
Commit 6266ecb

[AArch64][SME] Propagate desired ZA states in the MachineSMEABIPass
This patch adds a propagation step to the MachineSMEABIPass that propagates desired ZA states forwards/backwards (from predecessors to successors, or vice versa).

The aim of this is to pick better ZA states for edge bundles, as when many (or all) blocks in a bundle do not have a preferred ZA state, the ZA state assigned to a bundle can be less than ideal.

An important case is nested loops, where only the inner loop has a preferred ZA state. Here we'd like to propagate the ZA state up from the inner loop to the outer loops (to avoid saves/restores in any loop).

Change-Id: I39f9c7d7608e2fa070be2fb88351b4d1d0079041
1 parent 1a80766 commit 6266ecb
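
Below is a minimal, self-contained C++ sketch (an illustration, not the pass itself) of the propagation idea described above: a worklist algorithm in which a block with no legal desired state adopts the state preferred by the majority of its predecessors (forward pass) or successors (backward pass). The reduced ZAState enum, the Block struct, and the plain adjacency-list CFG are assumptions made for this example; the actual pass operates on MachineBasicBlocks and edge bundles, as shown in the diff to MachineSMEABIPass.cpp below.

#include <algorithm>
#include <cstdio>
#include <vector>

// Simplified stand-in for the pass's ZA states; ANY means "no preference".
enum ZAState { ANY, ACTIVE, LOCAL_SAVED, NUM_ZA_STATE };

struct Block {
  std::vector<int> Preds, Succs;           // CFG edges by block index.
  ZAState DesiredIn = ANY, DesiredOut = ANY;
};

static bool isLegal(ZAState S) { return S != ANY; }

// Propagate desired states across the CFG. If Forwards, a block's desired
// incoming state is filled from its predecessors' desired outgoing states;
// otherwise the roles are swapped.
static void propagate(std::vector<Block> &Blocks, bool Forwards) {
  std::vector<int> Worklist;
  for (int I = 0, E = (int)Blocks.size(); I != E; ++I)
    if (!isLegal(Forwards ? Blocks[I].DesiredIn : Blocks[I].DesiredOut))
      Worklist.push_back(I);

  while (!Worklist.empty()) {
    int BB = Worklist.back();
    Worklist.pop_back();
    Block &B = Blocks[BB];

    // Majority vote over the neighbours feeding this block.
    int Counts[NUM_ZA_STATE] = {0};
    for (int N : Forwards ? B.Preds : B.Succs) {
      ZAState S = Forwards ? Blocks[N].DesiredOut : Blocks[N].DesiredIn;
      if (isLegal(S))
        Counts[S]++;
    }
    ZAState Best =
        ZAState(std::max_element(std::begin(Counts), std::end(Counts)) -
                std::begin(Counts));

    ZAState &Cur = Forwards ? B.DesiredIn : B.DesiredOut;
    if (!isLegal(Best) || Best == Cur)
      continue;
    Cur = Best;
    // Fill the other side too if it had no preference, then revisit
    // neighbours that still lack a legal state.
    ZAState &Other = Forwards ? B.DesiredOut : B.DesiredIn;
    if (Other == ANY)
      Other = Best;
    for (int N : Forwards ? B.Succs : B.Preds)
      if (!isLegal(Forwards ? Blocks[N].DesiredIn : Blocks[N].DesiredOut))
        Worklist.push_back(N);
  }
}

int main() {
  // Nested-loop shape from the commit message: BB1 (inner loop) wants ACTIVE
  // on entry/exit, BB3 wants LOCAL_SAVED on entry, BB0/BB2 (outer loop) have
  // no preference.
  std::vector<Block> Blocks(4);
  auto Edge = [&](int From, int To) {
    Blocks[From].Succs.push_back(To);
    Blocks[To].Preds.push_back(From);
  };
  Edge(0, 1); Edge(1, 1); Edge(1, 2); Edge(2, 0); Edge(0, 3); Edge(2, 3);
  Blocks[1].DesiredIn = Blocks[1].DesiredOut = ACTIVE;
  Blocks[3].DesiredIn = LOCAL_SAVED;

  // Forwards first (inner loop -> outer loop), then backwards to fill gaps.
  for (bool Forwards : {true, false})
    propagate(Blocks, Forwards);

  for (int I = 0; I != 4; ++I)
    std::printf("BB%d: in=%d out=%d\n", I, Blocks[I].DesiredIn,
                Blocks[I].DesiredOut);
  return 0;
}

On this CFG the sketch leaves both the inner and outer loop blocks preferring ACTIVE and BB3 preferring LOCAL_SAVED on entry, which is the outcome the forwards-then-backwards ordering in the patch is meant to achieve.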

File tree

6 files changed: +504 -186 lines changed

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 115 additions & 17 deletions
@@ -121,8 +121,10 @@ struct InstInfo {
 /// Contains the needed ZA state for each instruction in a block. Instructions
 /// that do not require a ZA state are not recorded.
 struct BlockInfo {
-  ZAState FixedEntryState{ZAState::ANY};
   SmallVector<InstInfo> Insts;
+  ZAState FixedEntryState{ZAState::ANY};
+  ZAState DesiredIncomingState{ZAState::ANY};
+  ZAState DesiredOutgoingState{ZAState::ANY};
   LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
   LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
 };
@@ -268,6 +270,11 @@ struct MachineSMEABI : public MachineFunctionPass {
                         const EdgeBundles &Bundles,
                         ArrayRef<ZAState> BundleStates);
 
+  /// Propagates desired states forwards (from predecessors -> successors) if
+  /// \p Forwards, otherwise, propagates backwards (from successors ->
+  /// predecessors).
+  void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);
+
   // Emission routines for private and shared ZA functions (using lazy saves).
   void emitNewZAPrologue(MachineBasicBlock &MBB,
                          MachineBasicBlock::iterator MBBI);
@@ -411,12 +418,70 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
 
     // Reverse vector (as we had to iterate backwards for liveness).
     std::reverse(Block.Insts.begin(), Block.Insts.end());
+
+    // Record the desired states on entry/exit of this block. These are the
+    // states that would not incur a state transition.
+    if (!Block.Insts.empty()) {
+      Block.DesiredIncomingState = Block.Insts.front().NeededState;
+      Block.DesiredOutgoingState = Block.Insts.back().NeededState;
+    }
   }
 
   return FunctionInfo{std::move(Blocks), AfterSMEProloguePt,
                       PhysLiveRegsAfterSMEPrologue};
 }
 
+void MachineSMEABI::propagateDesiredStates(FunctionInfo &FnInfo,
+                                           bool Forwards) {
+  // If `Forwards`, this propagates desired states from predecessors to
+  // successors, otherwise, this propagates states from successors to
+  // predecessors.
+  auto GetBlockState = [](BlockInfo &Block, bool Incoming) -> ZAState & {
+    return Incoming ? Block.DesiredIncomingState : Block.DesiredOutgoingState;
+  };
+
+  SmallVector<MachineBasicBlock *> Worklist;
+  for (auto [BlockID, BlockInfo] : enumerate(FnInfo.Blocks)) {
+    if (!isLegalEdgeBundleZAState(GetBlockState(BlockInfo, Forwards)))
+      Worklist.push_back(MF->getBlockNumbered(BlockID));
+  }
+
+  while (!Worklist.empty()) {
+    MachineBasicBlock *MBB = Worklist.pop_back_val();
+    auto &BlockInfo = FnInfo.Blocks[MBB->getNumber()];
+
+    // Pick a legal edge bundle state that matches the majority of
+    // predecessors/successors.
+    int StateCounts[ZAState::NUM_ZA_STATE] = {0};
+    for (MachineBasicBlock *PredOrSucc :
+         Forwards ? predecessors(MBB) : successors(MBB)) {
+      auto &PredOrSuccBlockInfo = FnInfo.Blocks[PredOrSucc->getNumber()];
+      auto ZAState = GetBlockState(PredOrSuccBlockInfo, !Forwards);
+      if (isLegalEdgeBundleZAState(ZAState))
+        StateCounts[ZAState]++;
+    }
+
+    ZAState PropagatedState = ZAState(max_element(StateCounts) - StateCounts);
+    auto &CurrentState = GetBlockState(BlockInfo, Forwards);
+    if (PropagatedState != CurrentState) {
+      CurrentState = PropagatedState;
+      auto &OtherState = GetBlockState(BlockInfo, !Forwards);
+      // Propagate to the incoming/outgoing state if that is also "ANY".
+      if (OtherState == ZAState::ANY)
+        OtherState = PropagatedState;
+      // Push any successors/predecessors that may need updating to the
+      // worklist.
+      for (MachineBasicBlock *SuccOrPred :
+           Forwards ? successors(MBB) : predecessors(MBB)) {
+        auto &SuccOrPredBlockInfo = FnInfo.Blocks[SuccOrPred->getNumber()];
+        if (!isLegalEdgeBundleZAState(
+                GetBlockState(SuccOrPredBlockInfo, Forwards)))
+          Worklist.push_back(SuccOrPred);
+      }
+    }
+  }
+}
+
 /// Assigns each edge bundle a ZA state based on the needed states of blocks
 /// that have incoming or outgoing edges in that bundle.
 SmallVector<ZAState>
@@ -429,40 +494,36 @@ MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
     // Attempt to assign a ZA state for this bundle that minimizes state
     // transitions. Edges within loops are given a higher weight as we assume
     // they will be executed more than once.
-    // TODO: We should propagate desired incoming/outgoing states through blocks
-    // that have the "ANY" state first to make better global decisions.
     int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
     for (unsigned BlockID : Bundles.getBlocks(I)) {
       LLVM_DEBUG(dbgs() << "- bb." << BlockID);
 
       const BlockInfo &Block = FnInfo.Blocks[BlockID];
-      if (Block.Insts.empty()) {
-        LLVM_DEBUG(dbgs() << " (no state preference)\n");
-        continue;
-      }
       bool InEdge = Bundles.getBundle(BlockID, /*Out=*/false) == I;
       bool OutEdge = Bundles.getBundle(BlockID, /*Out=*/true) == I;
 
-      ZAState DesiredIncomingState = Block.Insts.front().NeededState;
-      if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
-        EdgeStateCounts[DesiredIncomingState]++;
+      bool LegalInEdge =
+          InEdge && isLegalEdgeBundleZAState(Block.DesiredIncomingState);
+      bool LegalOutEgde =
+          OutEdge && isLegalEdgeBundleZAState(Block.DesiredOutgoingState);
+      if (LegalInEdge) {
        LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
-                          << getZAStateString(DesiredIncomingState));
+                          << getZAStateString(Block.DesiredIncomingState));
+        EdgeStateCounts[Block.DesiredIncomingState]++;
      }
-      ZAState DesiredOutgoingState = Block.Insts.back().NeededState;
-      if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) {
-        EdgeStateCounts[DesiredOutgoingState]++;
+      if (LegalOutEgde) {
        LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
-                          << getZAStateString(DesiredOutgoingState));
+                          << getZAStateString(Block.DesiredOutgoingState));
+        EdgeStateCounts[Block.DesiredOutgoingState]++;
      }
+      if (!LegalInEdge && !LegalOutEgde)
+        LLVM_DEBUG(dbgs() << " (no state preference)");
       LLVM_DEBUG(dbgs() << '\n');
     }
 
     ZAState BundleState =
         ZAState(max_element(EdgeStateCounts) - EdgeStateCounts);
 
-    // Force ZA to be active in bundles that don't have a preferred state.
-    // TODO: Something better here (to avoid extra mode switches).
     if (BundleState == ZAState::ANY)
       BundleState = ZAState::ACTIVE;
 
@@ -858,6 +919,43 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
       getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();
 
   FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);
+
+  if (OptLevel != CodeGenOptLevel::None) {
+    // Propagate desired states forwards then backwards. We propagate forwards
+    // first as this propagates desired states from inner to outer loops.
+    // Backwards propagation is then used to fill in any gaps. Note: Doing both
+    // in one step can give poor results. For example:
+    //
+    //     ┌─────┐
+    //   ┌─┤ BB0 ◄───┐
+    //   │ └─┬───┘   │
+    //   │ ┌─▼───◄──┐│
+    //   │ │ BB1 │  ││
+    //   │ └─┬┬──┘  ││
+    //   │   │└─────┘│
+    //   │ ┌─▼───┐   │
+    //   │ │ BB2 ├───┘
+    //   │ └─┬───┘
+    //   │ ┌─▼───┐
+    //   └─► BB3 │
+    //     └─────┘
+    //
+    // If:
+    // - "BB0" and "BB2" (outer loop) has no state preference
+    // - "BB1" (inner loop) desires the ACTIVE state on entry/exit
+    // - "BB3" desires the LOCAL_SAVED state on entry
+    //
+    // If we propagate forwards first, ACTIVE is propagated from BB1 to BB2,
+    // then from BB2 to BB0. Which results in the inner and outer loops having
+    // the "ACTIVE" state. This avoids any state changes in the loops.
+    //
+    // If we propagate backwards first, we _could_ propagate LOCAL_SAVED from
+    // BB3 to BB0, which would result in a transition from ACTIVE -> LOCAL_SAVED
+    // in the outer loop.
+    for (bool Forwards : {true, false})
+      propagateDesiredStates(FnInfo, Forwards);
+  }
+
   SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);
 
   EmitContext Context;
llvm/test/CodeGen/AArch64/sme-agnostic-za.ll

Lines changed: 2 additions & 7 deletions
@@ -361,7 +361,6 @@ define i64 @test_many_callee_arguments(
   ret i64 %ret
 }
 
-; FIXME: The new lowering should avoid saves/restores in the probing loop.
 define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_state_agnostic" "probe-stack"="inline-asm" "stack-probe-size"="65536"{
 ; CHECK-LABEL: agnostic_za_buffer_alloc_with_stack_probes:
 ; CHECK: // %bb.0:
@@ -399,18 +398,14 @@ define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_s
 ; CHECK-NEWLOWERING-NEXT: bl __arm_sme_state_size
 ; CHECK-NEWLOWERING-NEXT: mov x8, sp
 ; CHECK-NEWLOWERING-NEXT: sub x19, x8, x0
+; CHECK-NEWLOWERING-NEXT: mov x0, x19
+; CHECK-NEWLOWERING-NEXT: bl __arm_sme_save
 ; CHECK-NEWLOWERING-NEXT: .LBB7_1: // =>This Inner Loop Header: Depth=1
 ; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16, lsl #12 // =65536
 ; CHECK-NEWLOWERING-NEXT: cmp sp, x19
-; CHECK-NEWLOWERING-NEXT: mov x0, x19
-; CHECK-NEWLOWERING-NEXT: mrs x8, NZCV
-; CHECK-NEWLOWERING-NEXT: bl __arm_sme_save
-; CHECK-NEWLOWERING-NEXT: msr NZCV, x8
 ; CHECK-NEWLOWERING-NEXT: b.le .LBB7_3
 ; CHECK-NEWLOWERING-NEXT: // %bb.2: // in Loop: Header=BB7_1 Depth=1
-; CHECK-NEWLOWERING-NEXT: mov x0, x19
 ; CHECK-NEWLOWERING-NEXT: str xzr, [sp]
-; CHECK-NEWLOWERING-NEXT: bl __arm_sme_restore
 ; CHECK-NEWLOWERING-NEXT: b .LBB7_1
 ; CHECK-NEWLOWERING-NEXT: .LBB7_3:
 ; CHECK-NEWLOWERING-NEXT: mov sp, x19

llvm/test/CodeGen/AArch64/sme-za-control-flow.ll

Lines changed: 40 additions & 67 deletions
@@ -228,65 +228,34 @@ exit:
   ret void
 }
 
-; FIXME: The codegen for this case could be improved (by tuning weights).
-; Here the ZA save has been hoisted out of the conditional, but would be better
-; to sink it.
 define void @cond_private_za_call(i1 %cond) "aarch64_inout_za" nounwind {
-; CHECK-LABEL: cond_private_za_call:
-; CHECK: // %bb.0:
-; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
-; CHECK-NEXT: mov x29, sp
-; CHECK-NEXT: sub sp, sp, #16
-; CHECK-NEXT: rdsvl x8, #1
-; CHECK-NEXT: mov x9, sp
-; CHECK-NEXT: msub x9, x8, x8, x9
-; CHECK-NEXT: mov sp, x9
-; CHECK-NEXT: stp x9, x8, [x29, #-16]
-; CHECK-NEXT: tbz w0, #0, .LBB3_4
-; CHECK-NEXT: // %bb.1: // %private_za_call
-; CHECK-NEXT: sub x8, x29, #16
-; CHECK-NEXT: msr TPIDR2_EL0, x8
-; CHECK-NEXT: bl private_za_call
-; CHECK-NEXT: smstart za
-; CHECK-NEXT: mrs x8, TPIDR2_EL0
-; CHECK-NEXT: sub x0, x29, #16
-; CHECK-NEXT: cbnz x8, .LBB3_3
-; CHECK-NEXT: // %bb.2: // %private_za_call
-; CHECK-NEXT: bl __arm_tpidr2_restore
-; CHECK-NEXT: .LBB3_3: // %private_za_call
-; CHECK-NEXT: msr TPIDR2_EL0, xzr
-; CHECK-NEXT: .LBB3_4: // %exit
-; CHECK-NEXT: mov sp, x29
-; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
-; CHECK-NEXT: b shared_za_call
-;
-; CHECK-NEWLOWERING-LABEL: cond_private_za_call:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT: mov x29, sp
-; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16
-; CHECK-NEWLOWERING-NEXT: rdsvl x8, #1
-; CHECK-NEWLOWERING-NEXT: mov x9, sp
-; CHECK-NEWLOWERING-NEXT: msub x9, x8, x8, x9
-; CHECK-NEWLOWERING-NEXT: mov sp, x9
-; CHECK-NEWLOWERING-NEXT: sub x10, x29, #16
-; CHECK-NEWLOWERING-NEXT: stp x9, x8, [x29, #-16]
-; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, x10
-; CHECK-NEWLOWERING-NEXT: tbz w0, #0, .LBB3_2
-; CHECK-NEWLOWERING-NEXT: // %bb.1: // %private_za_call
-; CHECK-NEWLOWERING-NEXT: bl private_za_call
-; CHECK-NEWLOWERING-NEXT: .LBB3_2: // %exit
-; CHECK-NEWLOWERING-NEXT: smstart za
-; CHECK-NEWLOWERING-NEXT: mrs x8, TPIDR2_EL0
-; CHECK-NEWLOWERING-NEXT: sub x0, x29, #16
-; CHECK-NEWLOWERING-NEXT: cbnz x8, .LBB3_4
-; CHECK-NEWLOWERING-NEXT: // %bb.3: // %exit
-; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_restore
-; CHECK-NEWLOWERING-NEXT: .LBB3_4: // %exit
-; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
-; CHECK-NEWLOWERING-NEXT: mov sp, x29
-; CHECK-NEWLOWERING-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT: b shared_za_call
+; CHECK-COMMON-LABEL: cond_private_za_call:
+; CHECK-COMMON: // %bb.0:
+; CHECK-COMMON-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-COMMON-NEXT: mov x29, sp
+; CHECK-COMMON-NEXT: sub sp, sp, #16
+; CHECK-COMMON-NEXT: rdsvl x8, #1
+; CHECK-COMMON-NEXT: mov x9, sp
+; CHECK-COMMON-NEXT: msub x9, x8, x8, x9
+; CHECK-COMMON-NEXT: mov sp, x9
+; CHECK-COMMON-NEXT: stp x9, x8, [x29, #-16]
+; CHECK-COMMON-NEXT: tbz w0, #0, .LBB3_4
+; CHECK-COMMON-NEXT: // %bb.1: // %private_za_call
+; CHECK-COMMON-NEXT: sub x8, x29, #16
+; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x8
+; CHECK-COMMON-NEXT: bl private_za_call
+; CHECK-COMMON-NEXT: smstart za
+; CHECK-COMMON-NEXT: mrs x8, TPIDR2_EL0
+; CHECK-COMMON-NEXT: sub x0, x29, #16
+; CHECK-COMMON-NEXT: cbnz x8, .LBB3_3
+; CHECK-COMMON-NEXT: // %bb.2: // %private_za_call
+; CHECK-COMMON-NEXT: bl __arm_tpidr2_restore
+; CHECK-COMMON-NEXT: .LBB3_3: // %private_za_call
+; CHECK-COMMON-NEXT: msr TPIDR2_EL0, xzr
+; CHECK-COMMON-NEXT: .LBB3_4: // %exit
+; CHECK-COMMON-NEXT: mov sp, x29
+; CHECK-COMMON-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-COMMON-NEXT: b shared_za_call
   br i1 %cond, label %private_za_call, label %exit
 
 private_za_call:
@@ -910,7 +879,7 @@ define void @loop_with_external_entry(i1 %c1, i1 %c2) "aarch64_inout_za" nounwin
 ; CHECK-NEWLOWERING-LABEL: loop_with_external_entry:
 ; CHECK-NEWLOWERING: // %bb.0: // %entry
 ; CHECK-NEWLOWERING-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT: str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: stp x20, x19, [sp, #16] // 16-byte Folded Spill
 ; CHECK-NEWLOWERING-NEXT: mov x29, sp
 ; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16
 ; CHECK-NEWLOWERING-NEXT: rdsvl x8, #1
@@ -923,23 +892,27 @@ define void @loop_with_external_entry(i1 %c1, i1 %c2) "aarch64_inout_za" nounwin
 ; CHECK-NEWLOWERING-NEXT: // %bb.1: // %init
 ; CHECK-NEWLOWERING-NEXT: bl shared_za_call
 ; CHECK-NEWLOWERING-NEXT: .LBB11_2: // %loop.preheader
-; CHECK-NEWLOWERING-NEXT: sub x8, x29, #16
-; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, x8
+; CHECK-NEWLOWERING-NEXT: sub x20, x29, #16
+; CHECK-NEWLOWERING-NEXT: b .LBB11_4
 ; CHECK-NEWLOWERING-NEXT: .LBB11_3: // %loop
+; CHECK-NEWLOWERING-NEXT: // in Loop: Header=BB11_4 Depth=1
+; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
+; CHECK-NEWLOWERING-NEXT: tbz w19, #0, .LBB11_6
+; CHECK-NEWLOWERING-NEXT: .LBB11_4: // %loop
 ; CHECK-NEWLOWERING-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, x20
 ; CHECK-NEWLOWERING-NEXT: bl private_za_call
-; CHECK-NEWLOWERING-NEXT: tbnz w19, #0, .LBB11_3
-; CHECK-NEWLOWERING-NEXT: // %bb.4: // %exit
 ; CHECK-NEWLOWERING-NEXT: smstart za
 ; CHECK-NEWLOWERING-NEXT: mrs x8, TPIDR2_EL0
 ; CHECK-NEWLOWERING-NEXT: sub x0, x29, #16
-; CHECK-NEWLOWERING-NEXT: cbnz x8, .LBB11_6
-; CHECK-NEWLOWERING-NEXT: // %bb.5: // %exit
+; CHECK-NEWLOWERING-NEXT: cbnz x8, .LBB11_3
+; CHECK-NEWLOWERING-NEXT: // %bb.5: // %loop
+; CHECK-NEWLOWERING-NEXT: // in Loop: Header=BB11_4 Depth=1
 ; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_restore
+; CHECK-NEWLOWERING-NEXT: b .LBB11_3
 ; CHECK-NEWLOWERING-NEXT: .LBB11_6: // %exit
-; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
 ; CHECK-NEWLOWERING-NEXT: mov sp, x29
-; CHECK-NEWLOWERING-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: ldp x20, x19, [sp, #16] // 16-byte Folded Reload
 ; CHECK-NEWLOWERING-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEWLOWERING-NEXT: ret
 entry:
