diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index c9a756da0078d..2e1e44ec48eb8 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -8023,6 +8023,17 @@ static bool isPassedInFPR(EVT VT) { (VT.isFloatingPoint() && !VT.isScalableVector()); } +static SDValue getZT0FrameIndex(MachineFrameInfo &MFI, + AArch64FunctionInfo &FuncInfo, + SelectionDAG &DAG) { + if (!FuncInfo.hasZT0SpillSlotIndex()) + FuncInfo.setZT0SpillSlotIndex(MFI.CreateSpillStackObject(64, Align(16))); + + return DAG.getFrameIndex( + FuncInfo.getZT0SpillSlotIndex(), + DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout())); +} + SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL, SelectionDAG &DAG) const { assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value"); @@ -9427,10 +9438,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // If the caller has ZT0 state which will not be preserved by the callee, // spill ZT0 before the call. if (ShouldPreserveZT0) { - unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16)); - ZTFrameIdx = DAG.getFrameIndex( - ZTObj, - DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout())); + ZTFrameIdx = getZT0FrameIndex(MFI, *FuncInfo, DAG); Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other), {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx}); diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h index 98fd018bf33a9..897c7e8539608 100644 --- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h +++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h @@ -239,6 +239,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo { // support). Register EarlyAllocSMESaveBuffer = AArch64::NoRegister; + // Holds the spill slot for ZT0. + int ZT0SpillSlotIndex = std::numeric_limits::max(); + // Note: The following properties are only used for the old SME ABI lowering: /// The frame-index for the TPIDR2 object used for lazy saves. TPIDR2Object TPIDR2; @@ -265,6 +268,15 @@ class AArch64FunctionInfo final : public MachineFunctionInfo { return EarlyAllocSMESaveBuffer; } + void setZT0SpillSlotIndex(int FI) { ZT0SpillSlotIndex = FI; } + int getZT0SpillSlotIndex() const { + assert(hasZT0SpillSlotIndex() && "ZT0 spill slot index not set!"); + return ZT0SpillSlotIndex; + } + bool hasZT0SpillSlotIndex() const { + return ZT0SpillSlotIndex != std::numeric_limits::max(); + } + // Old SME ABI lowering state getters/setters: Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; }; void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; }; diff --git a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll index 80827c2547780..062b68e5909f3 100644 --- a/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll +++ b/llvm/test/CodeGen/AArch64/sme-peephole-opts.ll @@ -224,22 +224,21 @@ define float @test6(float %f) nounwind "aarch64_pstate_sm_enabled" { define void @test7() nounwind "aarch64_inout_zt0" { ; CHECK-LABEL: test7: ; CHECK: // %bb.0: -; CHECK-NEXT: sub sp, sp, #144 -; CHECK-NEXT: stp x30, x19, [sp, #128] // 16-byte Folded Spill -; CHECK-NEXT: add x19, sp, #64 +; CHECK-NEXT: sub sp, sp, #80 +; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill +; CHECK-NEXT: mov x19, sp ; CHECK-NEXT: str zt0, [x19] ; CHECK-NEXT: smstop za ; CHECK-NEXT: bl callee ; CHECK-NEXT: smstart za ; CHECK-NEXT: ldr zt0, [x19] -; CHECK-NEXT: mov x19, sp ; CHECK-NEXT: str zt0, [x19] ; CHECK-NEXT: smstop za ; CHECK-NEXT: bl callee ; CHECK-NEXT: smstart za ; CHECK-NEXT: ldr zt0, [x19] -; CHECK-NEXT: ldp x30, x19, [sp, #128] // 16-byte Folded Reload -; CHECK-NEXT: add sp, sp, #144 +; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload +; CHECK-NEXT: add sp, sp, #80 ; CHECK-NEXT: ret call void @callee() call void @callee() diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll index 49eb368662b5d..2583a93e514a2 100644 --- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll +++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll @@ -386,3 +386,43 @@ define void @shared_za_new_zt0(ptr %callee) "aarch64_inout_za" "aarch64_new_zt0" call void %callee() "aarch64_inout_za" "aarch64_in_zt0"; ret void; } + + +define void @zt0_multiple_private_za_calls(ptr %callee) "aarch64_in_zt0" nounwind { +; CHECK-COMMON-LABEL: zt0_multiple_private_za_calls: +; CHECK-COMMON: // %bb.0: +; CHECK-COMMON-NEXT: sub sp, sp, #96 +; CHECK-COMMON-NEXT: stp x20, x19, [sp, #80] // 16-byte Folded Spill +; CHECK-COMMON-NEXT: mov x20, sp +; CHECK-COMMON-NEXT: mov x19, x0 +; CHECK-COMMON-NEXT: str x30, [sp, #64] // 8-byte Folded Spill +; CHECK-COMMON-NEXT: str zt0, [x20] +; CHECK-COMMON-NEXT: smstop za +; CHECK-COMMON-NEXT: blr x0 +; CHECK-COMMON-NEXT: smstart za +; CHECK-COMMON-NEXT: ldr zt0, [x20] +; CHECK-COMMON-NEXT: str zt0, [x20] +; CHECK-COMMON-NEXT: smstop za +; CHECK-COMMON-NEXT: blr x19 +; CHECK-COMMON-NEXT: smstart za +; CHECK-COMMON-NEXT: ldr zt0, [x20] +; CHECK-COMMON-NEXT: str zt0, [x20] +; CHECK-COMMON-NEXT: smstop za +; CHECK-COMMON-NEXT: blr x19 +; CHECK-COMMON-NEXT: smstart za +; CHECK-COMMON-NEXT: ldr zt0, [x20] +; CHECK-COMMON-NEXT: str zt0, [x20] +; CHECK-COMMON-NEXT: smstop za +; CHECK-COMMON-NEXT: blr x19 +; CHECK-COMMON-NEXT: smstart za +; CHECK-COMMON-NEXT: ldr zt0, [x20] +; CHECK-COMMON-NEXT: ldp x20, x19, [sp, #80] // 16-byte Folded Reload +; CHECK-COMMON-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload +; CHECK-COMMON-NEXT: add sp, sp, #96 +; CHECK-COMMON-NEXT: ret + call void %callee() + call void %callee() + call void %callee() + call void %callee() + ret void +}