Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8023,6 +8023,17 @@ static bool isPassedInFPR(EVT VT) {
(VT.isFloatingPoint() && !VT.isScalableVector());
}

static SDValue getZT0FrameIndex(MachineFrameInfo &MFI,
AArch64FunctionInfo &FuncInfo,
SelectionDAG &DAG) {
if (!FuncInfo.hasZT0SpillSlotIndex())
FuncInfo.setZT0SpillSlotIndex(MFI.CreateSpillStackObject(64, Align(16)));

return DAG.getFrameIndex(
FuncInfo.getZT0SpillSlotIndex(),
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
Copy link
Contributor

@gbossu gbossu Sep 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking: we make the assumption that a function cannot have multiple ZT0 scopes, which means we can always re-use the same frame index for saving/restoring.

Curious: Does that mean that functions using ZT0 can never be inlined?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't inline a new ZT0 function (like you can't inline a new ZA function), a shared ZT0 function can be inlined as that only results in a single ZT0 value/scope.

}

SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
SelectionDAG &DAG) const {
assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
Expand Down Expand Up @@ -9427,10 +9438,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// If the caller has ZT0 state which will not be preserved by the callee,
// spill ZT0 before the call.
if (ShouldPreserveZT0) {
unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
ZTFrameIdx = DAG.getFrameIndex(
ZTObj,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
ZTFrameIdx = getZT0FrameIndex(MFI, *FuncInfo, DAG);

Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
{Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
Expand Down
12 changes: 12 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
// support).
Register EarlyAllocSMESaveBuffer = AArch64::NoRegister;

// Holds the spill slot for ZT0.
int ZT0SpillSlotIndex = std::numeric_limits<int>::max();

// Note: The following properties are only used for the old SME ABI lowering:
/// The frame-index for the TPIDR2 object used for lazy saves.
TPIDR2Object TPIDR2;
Expand All @@ -265,6 +268,15 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
return EarlyAllocSMESaveBuffer;
}

void setZT0SpillSlotIndex(int FI) { ZT0SpillSlotIndex = FI; }
int getZT0SpillSlotIndex() const {
assert(hasZT0SpillSlotIndex() && "ZT0 spill slot index not set!");
return ZT0SpillSlotIndex;
}
bool hasZT0SpillSlotIndex() const {
return ZT0SpillSlotIndex != std::numeric_limits<int>::max();
}

// Old SME ABI lowering state getters/setters:
Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };
Expand Down
11 changes: 5 additions & 6 deletions llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
Original file line number Diff line number Diff line change
Expand Up @@ -224,22 +224,21 @@ define float @test6(float %f) nounwind "aarch64_pstate_sm_enabled" {
define void @test7() nounwind "aarch64_inout_zt0" {
; CHECK-LABEL: test7:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #144
; CHECK-NEXT: stp x30, x19, [sp, #128] // 16-byte Folded Spill
; CHECK-NEXT: add x19, sp, #64
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: ldp x30, x19, [sp, #128] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #144
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
call void @callee()
call void @callee()
Expand Down
40 changes: 40 additions & 0 deletions llvm/test/CodeGen/AArch64/sme-zt0-state.ll
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,43 @@ define void @shared_za_new_zt0(ptr %callee) "aarch64_inout_za" "aarch64_new_zt0"
call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
ret void;
}


define void @zt0_multiple_private_za_calls(ptr %callee) "aarch64_in_zt0" nounwind {
; CHECK-COMMON-LABEL: zt0_multiple_private_za_calls:
; CHECK-COMMON: // %bb.0:
; CHECK-COMMON-NEXT: sub sp, sp, #96
; CHECK-COMMON-NEXT: stp x20, x19, [sp, #80] // 16-byte Folded Spill
; CHECK-COMMON-NEXT: mov x20, sp
; CHECK-COMMON-NEXT: mov x19, x0
; CHECK-COMMON-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
; CHECK-COMMON-NEXT: str zt0, [x20]
; CHECK-COMMON-NEXT: smstop za
; CHECK-COMMON-NEXT: blr x0
; CHECK-COMMON-NEXT: smstart za
; CHECK-COMMON-NEXT: ldr zt0, [x20]
; CHECK-COMMON-NEXT: str zt0, [x20]
; CHECK-COMMON-NEXT: smstop za
; CHECK-COMMON-NEXT: blr x19
; CHECK-COMMON-NEXT: smstart za
; CHECK-COMMON-NEXT: ldr zt0, [x20]
; CHECK-COMMON-NEXT: str zt0, [x20]
; CHECK-COMMON-NEXT: smstop za
; CHECK-COMMON-NEXT: blr x19
; CHECK-COMMON-NEXT: smstart za
; CHECK-COMMON-NEXT: ldr zt0, [x20]
; CHECK-COMMON-NEXT: str zt0, [x20]
; CHECK-COMMON-NEXT: smstop za
; CHECK-COMMON-NEXT: blr x19
; CHECK-COMMON-NEXT: smstart za
; CHECK-COMMON-NEXT: ldr zt0, [x20]
; CHECK-COMMON-NEXT: ldp x20, x19, [sp, #80] // 16-byte Folded Reload
; CHECK-COMMON-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
; CHECK-COMMON-NEXT: add sp, sp, #96
; CHECK-COMMON-NEXT: ret
call void %callee()
call void %callee()
call void %callee()
call void %callee()
ret void
}
Loading