Skip to content

Commit ae3ec41

Browse files
committed
[AArch64][SME] Handle zeroing ZA and ZT0 in functions with ZT0 state
In the MachineSMEABIPass, if we have a function with ZT0 state, then there are some additional cases where we need to zero ZA and ZT0. If the function has a private ZA interface, i.e., new ZT0 (and new ZA if present). Then ZT0/ZA must be zeroed when committing the incoming ZA save. If the function has a shared ZA interface, e.g. new ZA and shared ZT0. Then ZA must be zeroed on function entry (without a ZA save commit). The logic in the ABI pass has been reworked to use an "ENTRY" state to handle this (rather than the more specific "CALLER_DORMANT" state). Change-Id: Ib91e9b13ffa4752320fe6a7a720afe919cf00198
1 parent aea31f3 commit ae3ec41

File tree

3 files changed

+68
-69
lines changed

3 files changed

+68
-69
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8735,15 +8735,6 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
87358735
}
87368736
}
87378737

8738-
if (getTM().useNewSMEABILowering()) {
8739-
// Clear new ZT0 state. TODO: Move this to the SME ABI pass.
8740-
if (Attrs.isNewZT0())
8741-
Chain = DAG.getNode(
8742-
ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8743-
DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
8744-
DAG.getTargetConstant(0, DL, MVT::i32));
8745-
}
8746-
87478738
return Chain;
87488739
}
87498740

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 57 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ enum ZAState {
8282
// A ZA save has been set up or committed (i.e. ZA is dormant or off)
8383
LOCAL_SAVED,
8484

85-
// ZA is off or a lazy save has been set up by the caller
86-
CALLER_DORMANT,
85+
// The ZA/ZT0 state on entry to the function.
86+
ENTRY,
8787

8888
// ZA is off
8989
OFF,
@@ -200,7 +200,7 @@ StringRef getZAStateString(ZAState State) {
200200
MAKE_CASE(ZAState::ANY)
201201
MAKE_CASE(ZAState::ACTIVE)
202202
MAKE_CASE(ZAState::LOCAL_SAVED)
203-
MAKE_CASE(ZAState::CALLER_DORMANT)
203+
MAKE_CASE(ZAState::ENTRY)
204204
MAKE_CASE(ZAState::OFF)
205205
default:
206206
llvm_unreachable("Unexpected ZAState");
@@ -281,8 +281,8 @@ struct MachineSMEABI : public MachineFunctionPass {
281281
void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);
282282

283283
// Emission routines for private and shared ZA functions (using lazy saves).
284-
void emitNewZAPrologue(MachineBasicBlock &MBB,
285-
MachineBasicBlock::iterator MBBI);
284+
void emitSMEPrologue(MachineBasicBlock &MBB,
285+
MachineBasicBlock::iterator MBBI);
286286
void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
287287
MachineBasicBlock::iterator MBBI,
288288
LiveRegs PhysLiveRegs);
@@ -395,9 +395,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
395395

396396
if (MBB.isEntryBlock()) {
397397
// Entry block:
398-
Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface()
399-
? ZAState::CALLER_DORMANT
400-
: ZAState::ACTIVE;
398+
Block.FixedEntryState = ZAState::ENTRY;
401399
} else if (MBB.isEHPad()) {
402400
// EH entry block:
403401
Block.FixedEntryState = ZAState::LOCAL_SAVED;
@@ -815,32 +813,49 @@ void MachineSMEABI::emitAllocateLazySaveBuffer(
815813
}
816814
}
817815

818-
void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
819-
MachineBasicBlock::iterator MBBI) {
816+
static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;
817+
818+
void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB,
819+
MachineBasicBlock::iterator MBBI) {
820820
auto *TLI = Subtarget->getTargetLowering();
821821
DebugLoc DL = getDebugLoc(MBB, MBBI);
822822

823-
// Get current TPIDR2_EL0.
824-
Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
825-
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
826-
.addReg(TPIDR2EL0, RegState::Define)
827-
.addImm(AArch64SysReg::TPIDR2_EL0);
828-
// If TPIDR2_EL0 is non-zero, commit the lazy save.
829-
// NOTE: Functions that only use ZT0 don't need to zero ZA.
830-
bool ZeroZA = AFI->getSMEFnAttrs().hasZAState();
831-
auto CommitZASave =
832-
BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
833-
.addReg(TPIDR2EL0)
834-
.addImm(ZeroZA ? 1 : 0)
835-
.addImm(/*ZeroZT0=*/false)
836-
.addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE))
837-
.addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
838-
if (ZeroZA)
839-
CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
840-
// Enable ZA (as ZA could have previously been in the OFF state).
841-
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
842-
.addImm(AArch64SVCR::SVCRZA)
843-
.addImm(1);
823+
bool ZeroZA = AFI->getSMEFnAttrs().isNewZA();
824+
bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0();
825+
if (AFI->getSMEFnAttrs().hasPrivateZAInterface()) {
826+
// Get current TPIDR2_EL0.
827+
Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
828+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
829+
.addReg(TPIDR2EL0, RegState::Define)
830+
.addImm(AArch64SysReg::TPIDR2_EL0);
831+
// If TPIDR2_EL0 is non-zero, commit the lazy save.
832+
// NOTE: Functions that only use ZT0 don't need to zero ZA.
833+
auto CommitZASave =
834+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
835+
.addReg(TPIDR2EL0)
836+
.addImm(ZeroZA)
837+
.addImm(ZeroZT0)
838+
.addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE))
839+
.addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
840+
if (ZeroZA)
841+
CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
842+
if (ZeroZT0)
843+
CommitZASave.addDef(AArch64::ZT0, RegState::ImplicitDefine);
844+
// Enable ZA (as ZA could have previously been in the OFF state).
845+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
846+
.addImm(AArch64SVCR::SVCRZA)
847+
.addImm(1);
848+
} else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) {
849+
if (ZeroZA) {
850+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_M))
851+
.addImm(ZERO_ALL_ZA_MASK)
852+
.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
853+
}
854+
if (ZeroZT0) {
855+
DebugLoc DL = getDebugLoc(MBB, MBBI);
856+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_T)).addDef(AArch64::ZT0);
857+
}
858+
}
844859
}
845860

846861
void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
@@ -922,19 +937,19 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
922937
if (From == ZAState::ANY || To == ZAState::ANY)
923938
return;
924939

925-
// If we're exiting from the CALLER_DORMANT state that means this new ZA
926-
// function did not touch ZA (so ZA was never turned on).
927-
if (From == ZAState::CALLER_DORMANT && To == ZAState::OFF)
940+
// If we're exiting from the ENTRY state that means that the function has not
941+
// used ZA, so in the case of private ZA/ZT0 functions we can omit any set up.
942+
if (From == ZAState::ENTRY && To == ZAState::OFF)
928943
return;
929944

945+
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
946+
930947
// TODO: Avoid setting up the save buffer if there's no transition to
931948
// LOCAL_SAVED.
932-
if (From == ZAState::CALLER_DORMANT) {
933-
assert(AFI->getSMEFnAttrs().hasPrivateZAInterface() &&
934-
"CALLER_DORMANT state requires private ZA interface");
949+
if (From == ZAState::ENTRY) {
935950
assert(&MBB == &MBB.getParent()->front() &&
936-
"CALLER_DORMANT state only valid in entry block");
937-
emitNewZAPrologue(MBB, MBB.getFirstNonPHI());
951+
"ENTRY state only valid in entry block");
952+
emitSMEPrologue(MBB, MBB.getFirstNonPHI());
938953
if (To == ZAState::ACTIVE)
939954
return; // Nothing more to do (ZA is active after the prologue).
940955

@@ -949,9 +964,9 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
949964
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
950965
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
951966
else if (To == ZAState::OFF) {
952-
assert(From != ZAState::CALLER_DORMANT &&
953-
"CALLER_DORMANT to OFF should have already been handled");
954-
assert(!AFI->getSMEFnAttrs().hasAgnosticZAInterface() &&
967+
assert(From != ZAState::ENTRY &&
968+
"ENTRY to OFF should have already been handled");
969+
assert(!SMEFnAttrs.hasAgnosticZAInterface() &&
955970
"Should not turn ZA off in agnostic ZA function");
956971
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
957972
} else {

llvm/test/CodeGen/AArch64/sme-zt0-state.ll

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ define void @zt0_new_caller_zt0_new_callee(ptr %callee) "aarch64_new_zt0" nounwi
199199
; CHECK-NEWLOWERING-NEXT: // %bb.1:
200200
; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save
201201
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
202+
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
202203
; CHECK-NEWLOWERING-NEXT: .LBB6_2:
203204
; CHECK-NEWLOWERING-NEXT: smstart za
204-
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
205205
; CHECK-NEWLOWERING-NEXT: mov x19, sp
206206
; CHECK-NEWLOWERING-NEXT: str zt0, [x19]
207207
; CHECK-NEWLOWERING-NEXT: smstop za
@@ -252,9 +252,9 @@ define i64 @zt0_new_caller_abi_routine_callee() "aarch64_new_zt0" nounwind {
252252
; CHECK-NEWLOWERING-NEXT: // %bb.1:
253253
; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save
254254
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
255+
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
255256
; CHECK-NEWLOWERING-NEXT: .LBB7_2:
256257
; CHECK-NEWLOWERING-NEXT: smstart za
257-
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
258258
; CHECK-NEWLOWERING-NEXT: mov x19, sp
259259
; CHECK-NEWLOWERING-NEXT: str zt0, [x19]
260260
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_state
@@ -302,9 +302,9 @@ define void @zt0_new_caller(ptr %callee) "aarch64_new_zt0" nounwind {
302302
; CHECK-NEWLOWERING-NEXT: // %bb.1:
303303
; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save
304304
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
305+
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
305306
; CHECK-NEWLOWERING-NEXT: .LBB8_2:
306307
; CHECK-NEWLOWERING-NEXT: smstart za
307-
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
308308
; CHECK-NEWLOWERING-NEXT: blr x0
309309
; CHECK-NEWLOWERING-NEXT: smstop za
310310
; CHECK-NEWLOWERING-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -343,9 +343,9 @@ define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" n
343343
; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save
344344
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
345345
; CHECK-NEWLOWERING-NEXT: zero {za}
346+
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
346347
; CHECK-NEWLOWERING-NEXT: .LBB9_2:
347348
; CHECK-NEWLOWERING-NEXT: smstart za
348-
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
349349
; CHECK-NEWLOWERING-NEXT: blr x0
350350
; CHECK-NEWLOWERING-NEXT: smstop za
351351
; CHECK-NEWLOWERING-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -356,20 +356,13 @@ define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" n
356356

357357
; Expect clear ZA on entry
358358
define void @new_za_shared_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_in_zt0" nounwind {
359-
; CHECK-LABEL: new_za_shared_zt0_caller:
360-
; CHECK: // %bb.0:
361-
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
362-
; CHECK-NEXT: zero {za}
363-
; CHECK-NEXT: blr x0
364-
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
365-
; CHECK-NEXT: ret
366-
;
367-
; CHECK-NEWLOWERING-LABEL: new_za_shared_zt0_caller:
368-
; CHECK-NEWLOWERING: // %bb.0:
369-
; CHECK-NEWLOWERING-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
370-
; CHECK-NEWLOWERING-NEXT: blr x0
371-
; CHECK-NEWLOWERING-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
372-
; CHECK-NEWLOWERING-NEXT: ret
359+
; CHECK-COMMON-LABEL: new_za_shared_zt0_caller:
360+
; CHECK-COMMON: // %bb.0:
361+
; CHECK-COMMON-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
362+
; CHECK-COMMON-NEXT: zero {za}
363+
; CHECK-COMMON-NEXT: blr x0
364+
; CHECK-COMMON-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
365+
; CHECK-COMMON-NEXT: ret
373366
call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
374367
ret void;
375368
}

0 commit comments

Comments
 (0)