Skip to content

Commit ce36365

Browse files
committed
[AArch64][SME] Propagate desired ZA states in the MachineSMEABIPass
This patch adds a propagation step to the MachineSMEABIPass that propagates desired ZA states forwards (from predecessors to successors) or backwards (from successors to predecessors). The aim is to pick better ZA states for edge bundles: when many (or all) blocks in a bundle do not have a preferred ZA state, the ZA state assigned to the bundle can be less than ideal. An important case is nested loops, where only the inner loop has a preferred ZA state. Here we'd like to propagate the ZA state up from the inner loop to the outer loops (to avoid saves/restores in any loop). Change-Id: I39f9c7d7608e2fa070be2fb88351b4d1d0079041
1 parent 42b551a commit ce36365

File tree

6 files changed

+472
-187
lines changed

6 files changed

+472
-187
lines changed

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,11 @@ struct MachineSMEABI : public MachineFunctionPass {
203203
/// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
204204
void insertStateChanges();
205205

206+
/// Propagates desired states forwards (from predecessors -> successors) if
207+
/// \p Forwards, otherwise, propagates backwards (from successors ->
208+
/// predecessors).
209+
void propagateDesiredStates(bool Forwards = true);
210+
206211
// Emission routines for private and shared ZA functions (using lazy saves).
207212
void emitNewZAPrologue(MachineBasicBlock &MBB,
208213
MachineBasicBlock::iterator MBBI);
@@ -277,8 +282,10 @@ struct MachineSMEABI : public MachineFunctionPass {
277282
/// Contains the needed ZA state for each instruction in a block.
278283
/// Instructions that do not require a ZA state are not recorded.
279284
struct BlockInfo {
280-
ZAState FixedEntryState{ZAState::ANY};
281285
SmallVector<InstInfo> Insts;
286+
ZAState FixedEntryState{ZAState::ANY};
287+
ZAState DesiredIncomingState{ZAState::ANY};
288+
ZAState DesiredOutgoingState{ZAState::ANY};
282289
LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
283290
LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
284291
};
@@ -371,51 +378,105 @@ void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
371378

372379
// Reverse vector (as we had to iterate backwards for liveness).
373380
std::reverse(Block.Insts.begin(), Block.Insts.end());
381+
382+
// Record the desired states on entry/exit of this block. These are the
383+
// states that would not incur a state transition.
384+
if (!Block.Insts.empty()) {
385+
Block.DesiredIncomingState = Block.Insts.front().NeededState;
386+
Block.DesiredOutgoingState = Block.Insts.back().NeededState;
387+
}
388+
}
389+
}
390+
391+
// Worklist-driven propagation of desired ZA states between neighbouring
// blocks. A block is processed when its desired state in the direction of
// travel (incoming when `Forwards`, outgoing otherwise) is not a legal edge
// bundle state; it then adopts the state most common among its predecessors'
// outgoing states (`Forwards`) or its successors' incoming states (backwards).
// NOTE(review): the locals named `BlockInfo` and `ZAState` below shadow the
// types of the same name within their scopes; consider renaming them.
void MachineSMEABI::propagateDesiredStates(bool Forwards) {
392+
// If `Forwards`, this propagates desired states from predecessors to
393+
// successors, otherwise, this propagates states from successors to
394+
// predecessors.
395+
auto GetBlockState = [](BlockInfo &Block, bool Incoming) -> ZAState & {
396+
return Incoming ? Block.DesiredIncomingState : Block.DesiredOutgoingState;
397+
};
398+
399+
// Seed the worklist with every block that has no legal edge bundle state in
// the direction being propagated.
SmallVector<MachineBasicBlock *> Worklist;
400+
for (auto [BlockID, BlockInfo] : enumerate(State.Blocks)) {
401+
if (!isLegalEdgeBundleZAState(GetBlockState(BlockInfo, Forwards)))
402+
Worklist.push_back(MF->getBlockNumbered(BlockID));
403+
}
404+
405+
while (!Worklist.empty()) {
406+
MachineBasicBlock *MBB = Worklist.pop_back_val();
407+
auto &BlockInfo = State.Blocks[MBB->getNumber()];
408+
409+
// Pick a legal edge bundle state that matches the majority of
410+
// predecessors/successors.
411+
int StateCounts[ZAState::NUM_ZA_STATE] = {0};
412+
for (MachineBasicBlock *PredOrSucc :
413+
Forwards ? predecessors(MBB) : successors(MBB)) {
414+
auto &PredOrSuccBlockInfo = State.Blocks[PredOrSucc->getNumber()];
415+
// Read the neighbour's state on the side facing this block (hence
// `!Forwards`).
auto ZAState = GetBlockState(PredOrSuccBlockInfo, !Forwards);
416+
if (isLegalEdgeBundleZAState(ZAState))
417+
StateCounts[ZAState]++;
418+
}
419+
420+
// The index of the largest count is converted back into a state.
// NOTE(review): assuming max_element behaves like std::max_element, ties
// (including the all-zero case where no neighbour has a legal state)
// resolve to the lowest-valued state, i.e. ZAState(0) -- presumably ANY;
// TODO confirm against the enum definition.
ZAState PropagatedState = ZAState(max_element(StateCounts) - StateCounts);
421+
auto &CurrentState = GetBlockState(BlockInfo, Forwards);
422+
// Only a change of state triggers further work, so blocks that already
// hold the propagated state do not re-queue their neighbours.
if (PropagatedState != CurrentState) {
423+
CurrentState = PropagatedState;
424+
auto &OtherState = GetBlockState(BlockInfo, !Forwards);
425+
// Propagate to the incoming/outgoing state if that is also "ANY".
426+
if (OtherState == ZAState::ANY)
427+
OtherState = PropagatedState;
428+
// Push any successors/predecessors that may need updating to the
429+
// worklist.
430+
for (MachineBasicBlock *SuccOrPred :
431+
Forwards ? successors(MBB) : predecessors(MBB)) {
432+
auto &SuccOrPredBlockInfo = State.Blocks[SuccOrPred->getNumber()];
433+
if (!isLegalEdgeBundleZAState(
434+
GetBlockState(SuccOrPredBlockInfo, Forwards)))
435+
Worklist.push_back(SuccOrPred);
436+
}
437+
}
374438
}
375439
}
376440

377441
void MachineSMEABI::assignBundleZAStates() {
378442
State.BundleStates.resize(Bundles->getNumBundles());
443+
379444
for (unsigned I = 0, E = Bundles->getNumBundles(); I != E; ++I) {
380445
LLVM_DEBUG(dbgs() << "Assigning ZA state for edge bundle: " << I << '\n');
381446

382447
// Attempt to assign a ZA state for this bundle that minimizes state
383448
// transitions. Edges within loops are given a higher weight as we assume
384449
// they will be executed more than once.
385-
// TODO: We should propagate desired incoming/outgoing states through blocks
386-
// that have the "ANY" state first to make better global decisions.
387450
int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
388451
for (unsigned BlockID : Bundles->getBlocks(I)) {
389452
LLVM_DEBUG(dbgs() << "- bb." << BlockID);
390453

391-
const BlockInfo &Block = State.Blocks[BlockID];
392-
if (Block.Insts.empty()) {
393-
LLVM_DEBUG(dbgs() << " (no state preference)\n");
394-
continue;
395-
}
454+
BlockInfo &Block = State.Blocks[BlockID];
396455
bool InEdge = Bundles->getBundle(BlockID, /*Out=*/false) == I;
397456
bool OutEdge = Bundles->getBundle(BlockID, /*Out=*/true) == I;
398457

399-
ZAState DesiredIncomingState = Block.Insts.front().NeededState;
400-
if (InEdge && isLegalEdgeBundleZAState(DesiredIncomingState)) {
401-
EdgeStateCounts[DesiredIncomingState]++;
458+
bool LegalInEdge =
459+
InEdge && isLegalEdgeBundleZAState(Block.DesiredIncomingState);
460+
bool LegalOutEgde =
461+
OutEdge && isLegalEdgeBundleZAState(Block.DesiredOutgoingState);
462+
if (LegalInEdge) {
402463
LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
403-
<< getZAStateString(DesiredIncomingState));
464+
<< getZAStateString(Block.DesiredIncomingState));
465+
EdgeStateCounts[Block.DesiredIncomingState]++;
404466
}
405-
ZAState DesiredOutgoingState = Block.Insts.back().NeededState;
406-
if (OutEdge && isLegalEdgeBundleZAState(DesiredOutgoingState)) {
407-
EdgeStateCounts[DesiredOutgoingState]++;
467+
if (LegalOutEgde) {
408468
LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
409-
<< getZAStateString(DesiredOutgoingState));
469+
<< getZAStateString(Block.DesiredOutgoingState));
470+
EdgeStateCounts[Block.DesiredOutgoingState]++;
410471
}
472+
if (!LegalInEdge && !LegalOutEgde)
473+
LLVM_DEBUG(dbgs() << " (no state preference)");
411474
LLVM_DEBUG(dbgs() << '\n');
412475
}
413476

414477
ZAState BundleState =
415478
ZAState(max_element(EdgeStateCounts) - EdgeStateCounts);
416479

417-
// Force ZA to be active in bundles that don't have a preferred state.
418-
// TODO: Something better here (to avoid extra mode switches).
419480
if (BundleState == ZAState::ANY)
420481
BundleState = ZAState::ACTIVE;
421482

@@ -821,6 +882,10 @@ bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
821882
MRI = &MF.getRegInfo();
822883

823884
collectNeededZAStates(SMEFnAttrs);
885+
if (OptLevel != CodeGenOptLevel::None) {
886+
for (bool Forwards : {true, false})
887+
propagateDesiredStates(Forwards);
888+
}
824889
assignBundleZAStates();
825890
insertStateChanges();
826891

llvm/test/CodeGen/AArch64/sme-agnostic-za.ll

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,6 @@ define i64 @test_many_callee_arguments(
361361
ret i64 %ret
362362
}
363363

364-
; FIXME: The new lowering should avoid saves/restores in the probing loop.
365364
define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_state_agnostic" "probe-stack"="inline-asm" "stack-probe-size"="65536"{
366365
; CHECK-LABEL: agnostic_za_buffer_alloc_with_stack_probes:
367366
; CHECK: // %bb.0:
@@ -399,18 +398,14 @@ define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_s
399398
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_state_size
400399
; CHECK-NEWLOWERING-NEXT: mov x8, sp
401400
; CHECK-NEWLOWERING-NEXT: sub x19, x8, x0
401+
; CHECK-NEWLOWERING-NEXT: mov x0, x19
402+
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_save
402403
; CHECK-NEWLOWERING-NEXT: .LBB7_1: // =>This Inner Loop Header: Depth=1
403404
; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16, lsl #12 // =65536
404405
; CHECK-NEWLOWERING-NEXT: cmp sp, x19
405-
; CHECK-NEWLOWERING-NEXT: mov x0, x19
406-
; CHECK-NEWLOWERING-NEXT: mrs x8, NZCV
407-
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_save
408-
; CHECK-NEWLOWERING-NEXT: msr NZCV, x8
409406
; CHECK-NEWLOWERING-NEXT: b.le .LBB7_3
410407
; CHECK-NEWLOWERING-NEXT: // %bb.2: // in Loop: Header=BB7_1 Depth=1
411-
; CHECK-NEWLOWERING-NEXT: mov x0, x19
412408
; CHECK-NEWLOWERING-NEXT: str xzr, [sp]
413-
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_restore
414409
; CHECK-NEWLOWERING-NEXT: b .LBB7_1
415410
; CHECK-NEWLOWERING-NEXT: .LBB7_3:
416411
; CHECK-NEWLOWERING-NEXT: mov sp, x19

llvm/test/CodeGen/AArch64/sme-za-control-flow.ll

Lines changed: 40 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -228,65 +228,34 @@ exit:
228228
ret void
229229
}
230230

231-
; FIXME: The codegen for this case could be improved (by tuning weights).
232-
; Here the ZA save has been hoisted out of the conditional, but would be better
233-
; to sink it.
234231
define void @cond_private_za_call(i1 %cond) "aarch64_inout_za" nounwind {
235-
; CHECK-LABEL: cond_private_za_call:
236-
; CHECK: // %bb.0:
237-
; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
238-
; CHECK-NEXT: mov x29, sp
239-
; CHECK-NEXT: sub sp, sp, #16
240-
; CHECK-NEXT: rdsvl x8, #1
241-
; CHECK-NEXT: mov x9, sp
242-
; CHECK-NEXT: msub x9, x8, x8, x9
243-
; CHECK-NEXT: mov sp, x9
244-
; CHECK-NEXT: stp x9, x8, [x29, #-16]
245-
; CHECK-NEXT: tbz w0, #0, .LBB3_4
246-
; CHECK-NEXT: // %bb.1: // %private_za_call
247-
; CHECK-NEXT: sub x8, x29, #16
248-
; CHECK-NEXT: msr TPIDR2_EL0, x8
249-
; CHECK-NEXT: bl private_za_call
250-
; CHECK-NEXT: smstart za
251-
; CHECK-NEXT: mrs x8, TPIDR2_EL0
252-
; CHECK-NEXT: sub x0, x29, #16
253-
; CHECK-NEXT: cbnz x8, .LBB3_3
254-
; CHECK-NEXT: // %bb.2: // %private_za_call
255-
; CHECK-NEXT: bl __arm_tpidr2_restore
256-
; CHECK-NEXT: .LBB3_3: // %private_za_call
257-
; CHECK-NEXT: msr TPIDR2_EL0, xzr
258-
; CHECK-NEXT: .LBB3_4: // %exit
259-
; CHECK-NEXT: mov sp, x29
260-
; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
261-
; CHECK-NEXT: b shared_za_call
262-
;
263-
; CHECK-NEWLOWERING-LABEL: cond_private_za_call:
264-
; CHECK-NEWLOWERING: // %bb.0:
265-
; CHECK-NEWLOWERING-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
266-
; CHECK-NEWLOWERING-NEXT: mov x29, sp
267-
; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16
268-
; CHECK-NEWLOWERING-NEXT: rdsvl x8, #1
269-
; CHECK-NEWLOWERING-NEXT: mov x9, sp
270-
; CHECK-NEWLOWERING-NEXT: msub x9, x8, x8, x9
271-
; CHECK-NEWLOWERING-NEXT: mov sp, x9
272-
; CHECK-NEWLOWERING-NEXT: sub x10, x29, #16
273-
; CHECK-NEWLOWERING-NEXT: stp x9, x8, [x29, #-16]
274-
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, x10
275-
; CHECK-NEWLOWERING-NEXT: tbz w0, #0, .LBB3_2
276-
; CHECK-NEWLOWERING-NEXT: // %bb.1: // %private_za_call
277-
; CHECK-NEWLOWERING-NEXT: bl private_za_call
278-
; CHECK-NEWLOWERING-NEXT: .LBB3_2: // %exit
279-
; CHECK-NEWLOWERING-NEXT: smstart za
280-
; CHECK-NEWLOWERING-NEXT: mrs x8, TPIDR2_EL0
281-
; CHECK-NEWLOWERING-NEXT: sub x0, x29, #16
282-
; CHECK-NEWLOWERING-NEXT: cbnz x8, .LBB3_4
283-
; CHECK-NEWLOWERING-NEXT: // %bb.3: // %exit
284-
; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_restore
285-
; CHECK-NEWLOWERING-NEXT: .LBB3_4: // %exit
286-
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
287-
; CHECK-NEWLOWERING-NEXT: mov sp, x29
288-
; CHECK-NEWLOWERING-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
289-
; CHECK-NEWLOWERING-NEXT: b shared_za_call
232+
; CHECK-COMMON-LABEL: cond_private_za_call:
233+
; CHECK-COMMON: // %bb.0:
234+
; CHECK-COMMON-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
235+
; CHECK-COMMON-NEXT: mov x29, sp
236+
; CHECK-COMMON-NEXT: sub sp, sp, #16
237+
; CHECK-COMMON-NEXT: rdsvl x8, #1
238+
; CHECK-COMMON-NEXT: mov x9, sp
239+
; CHECK-COMMON-NEXT: msub x9, x8, x8, x9
240+
; CHECK-COMMON-NEXT: mov sp, x9
241+
; CHECK-COMMON-NEXT: stp x9, x8, [x29, #-16]
242+
; CHECK-COMMON-NEXT: tbz w0, #0, .LBB3_4
243+
; CHECK-COMMON-NEXT: // %bb.1: // %private_za_call
244+
; CHECK-COMMON-NEXT: sub x8, x29, #16
245+
; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x8
246+
; CHECK-COMMON-NEXT: bl private_za_call
247+
; CHECK-COMMON-NEXT: smstart za
248+
; CHECK-COMMON-NEXT: mrs x8, TPIDR2_EL0
249+
; CHECK-COMMON-NEXT: sub x0, x29, #16
250+
; CHECK-COMMON-NEXT: cbnz x8, .LBB3_3
251+
; CHECK-COMMON-NEXT: // %bb.2: // %private_za_call
252+
; CHECK-COMMON-NEXT: bl __arm_tpidr2_restore
253+
; CHECK-COMMON-NEXT: .LBB3_3: // %private_za_call
254+
; CHECK-COMMON-NEXT: msr TPIDR2_EL0, xzr
255+
; CHECK-COMMON-NEXT: .LBB3_4: // %exit
256+
; CHECK-COMMON-NEXT: mov sp, x29
257+
; CHECK-COMMON-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
258+
; CHECK-COMMON-NEXT: b shared_za_call
290259
br i1 %cond, label %private_za_call, label %exit
291260

292261
private_za_call:
@@ -910,7 +879,7 @@ define void @loop_with_external_entry(i1 %c1, i1 %c2) "aarch64_inout_za" nounwin
910879
; CHECK-NEWLOWERING-LABEL: loop_with_external_entry:
911880
; CHECK-NEWLOWERING: // %bb.0: // %entry
912881
; CHECK-NEWLOWERING-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
913-
; CHECK-NEWLOWERING-NEXT: str x19, [sp, #16] // 8-byte Folded Spill
882+
; CHECK-NEWLOWERING-NEXT: stp x20, x19, [sp, #16] // 16-byte Folded Spill
914883
; CHECK-NEWLOWERING-NEXT: mov x29, sp
915884
; CHECK-NEWLOWERING-NEXT: sub sp, sp, #16
916885
; CHECK-NEWLOWERING-NEXT: rdsvl x8, #1
@@ -923,23 +892,27 @@ define void @loop_with_external_entry(i1 %c1, i1 %c2) "aarch64_inout_za" nounwin
923892
; CHECK-NEWLOWERING-NEXT: // %bb.1: // %init
924893
; CHECK-NEWLOWERING-NEXT: bl shared_za_call
925894
; CHECK-NEWLOWERING-NEXT: .LBB11_2: // %loop.preheader
926-
; CHECK-NEWLOWERING-NEXT: sub x8, x29, #16
927-
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, x8
895+
; CHECK-NEWLOWERING-NEXT: sub x20, x29, #16
896+
; CHECK-NEWLOWERING-NEXT: b .LBB11_4
928897
; CHECK-NEWLOWERING-NEXT: .LBB11_3: // %loop
898+
; CHECK-NEWLOWERING-NEXT: // in Loop: Header=BB11_4 Depth=1
899+
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
900+
; CHECK-NEWLOWERING-NEXT: tbz w19, #0, .LBB11_6
901+
; CHECK-NEWLOWERING-NEXT: .LBB11_4: // %loop
929902
; CHECK-NEWLOWERING-NEXT: // =>This Inner Loop Header: Depth=1
903+
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, x20
930904
; CHECK-NEWLOWERING-NEXT: bl private_za_call
931-
; CHECK-NEWLOWERING-NEXT: tbnz w19, #0, .LBB11_3
932-
; CHECK-NEWLOWERING-NEXT: // %bb.4: // %exit
933905
; CHECK-NEWLOWERING-NEXT: smstart za
934906
; CHECK-NEWLOWERING-NEXT: mrs x8, TPIDR2_EL0
935907
; CHECK-NEWLOWERING-NEXT: sub x0, x29, #16
936-
; CHECK-NEWLOWERING-NEXT: cbnz x8, .LBB11_6
937-
; CHECK-NEWLOWERING-NEXT: // %bb.5: // %exit
908+
; CHECK-NEWLOWERING-NEXT: cbnz x8, .LBB11_3
909+
; CHECK-NEWLOWERING-NEXT: // %bb.5: // %loop
910+
; CHECK-NEWLOWERING-NEXT: // in Loop: Header=BB11_4 Depth=1
938911
; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_restore
912+
; CHECK-NEWLOWERING-NEXT: b .LBB11_3
939913
; CHECK-NEWLOWERING-NEXT: .LBB11_6: // %exit
940-
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
941914
; CHECK-NEWLOWERING-NEXT: mov sp, x29
942-
; CHECK-NEWLOWERING-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload
915+
; CHECK-NEWLOWERING-NEXT: ldp x20, x19, [sp, #16] // 16-byte Folded Reload
943916
; CHECK-NEWLOWERING-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
944917
; CHECK-NEWLOWERING-NEXT: ret
945918
entry:

0 commit comments

Comments
 (0)