From 1f80d30951a619acb9d694a5046d46a3a6a8c70d Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Wed, 4 Sep 2024 15:45:45 +0100 Subject: [PATCH 1/5] [AArch64] SME implementation for agnostic-ZA functions This implements the lowering of calls from agnostic-ZA functions to non-agnostic-ZA functions, using the ABI routines `__arm_sme_state_size`, `__arm_sme_save` and `__arm_sme_restore`. This implements the proposal described in the following PRs: * https://github.com/ARM-software/acle/pull/336 * https://github.com/ARM-software/abi-aa/pull/264 --- llvm/lib/IR/Verifier.cpp | 24 ++-- llvm/lib/Target/AArch64/AArch64FastISel.cpp | 3 +- .../Target/AArch64/AArch64ISelLowering.cpp | 129 +++++++++++++++++- llvm/lib/Target/AArch64/AArch64ISelLowering.h | 6 + .../AArch64/AArch64MachineFunctionInfo.h | 14 ++ .../lib/Target/AArch64/AArch64SMEInstrInfo.td | 16 +++ .../AArch64/AArch64TargetTransformInfo.cpp | 8 +- .../AArch64/Utils/AArch64SMEAttributes.cpp | 9 ++ .../AArch64/Utils/AArch64SMEAttributes.h | 17 ++- llvm/test/CodeGen/AArch64/sme-agnostic-za.ll | 84 ++++++++++++ .../AArch64/sme-disable-gisel-fisel.ll | 24 ++++ llvm/test/Verifier/sme-attributes.ll | 46 ++++--- 12 files changed, 342 insertions(+), 38 deletions(-) create mode 100644 llvm/test/CodeGen/AArch64/sme-agnostic-za.ll diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 48e27763017be..7b6f7b5aa6171 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -2268,19 +2268,23 @@ void Verifier::verifyFunctionAttrs(FunctionType *FT, AttributeList Attrs, Check((Attrs.hasFnAttr("aarch64_new_za") + Attrs.hasFnAttr("aarch64_in_za") + Attrs.hasFnAttr("aarch64_inout_za") + Attrs.hasFnAttr("aarch64_out_za") + - Attrs.hasFnAttr("aarch64_preserves_za")) <= 1, + Attrs.hasFnAttr("aarch64_preserves_za") + + Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1, "Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', " - "'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive", + "'aarch64_inout_za', 'aarch64_preserves_za' and " + "'aarch64_za_state_agnostic' are mutually exclusive", V); - Check( - (Attrs.hasFnAttr("aarch64_new_zt0") + Attrs.hasFnAttr("aarch64_in_zt0") + - Attrs.hasFnAttr("aarch64_inout_zt0") + - Attrs.hasFnAttr("aarch64_out_zt0") + - Attrs.hasFnAttr("aarch64_preserves_zt0")) <= 1, - "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', " - "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive", - V); + Check((Attrs.hasFnAttr("aarch64_new_zt0") + + Attrs.hasFnAttr("aarch64_in_zt0") + + Attrs.hasFnAttr("aarch64_inout_zt0") + + Attrs.hasFnAttr("aarch64_out_zt0") + + Attrs.hasFnAttr("aarch64_preserves_zt0") + + Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1, + "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', " + "'aarch64_inout_zt0', 'aarch64_preserves_zt0' and " + "'aarch64_za_state_agnostic' are mutually exclusive", + V); if (Attrs.hasFnAttr(Attribute::JumpTable)) { const GlobalValue *GV = cast(V); diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp index 9f0f23b6e6a65..738895998c119 100644 --- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp +++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp @@ -5197,7 +5197,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo, SMEAttrs CallerAttrs(*FuncInfo.Fn); if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() || CallerAttrs.hasStreamingInterfaceOrBody() || - CallerAttrs.hasStreamingCompatibleInterface()) + CallerAttrs.hasStreamingCompatibleInterface() || + CallerAttrs.hasAgnosticZAInterface()) return nullptr; return new AArch64FastISel(FuncInfo, LibInfo); } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index e455fabfe2e8d..a9ff4860da580 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -2643,6 +2643,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { break; MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER) MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ) + MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE) + MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER) MAKE_CASE(AArch64ISD::COALESCER_BARRIER) MAKE_CASE(AArch64ISD::VG_SAVE) MAKE_CASE(AArch64ISD::VG_RESTORE) @@ -3230,6 +3232,39 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI, return BB; } +// TODO: Find a way to merge this with EmitAllocateZABuffer. +MachineBasicBlock * +AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI, + MachineBasicBlock *BB) const { + MachineFunction *MF = BB->getParent(); + MachineFrameInfo &MFI = MF->getFrameInfo(); + AArch64FunctionInfo *FuncInfo = MF->getInfo(); + assert(!MF->getSubtarget().isTargetWindows() && + "Lazy ZA save is not yet supported on Windows"); + + const TargetInstrInfo *TII = Subtarget->getInstrInfo(); + if (FuncInfo->getSMESaveBufferUsed()) { + // Allocate a lazy-save buffer object of the size given, normally SVL * SVL + auto Size = MI.getOperand(1).getReg(); + auto Dest = MI.getOperand(0).getReg(); + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest) + .addReg(AArch64::SP) + .addReg(Size) + .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0)); + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), + AArch64::SP) + .addReg(Dest); + + // We have just allocated a variable sized object, tell this to PEI. + MFI.CreateVariableSizedObject(Align(16), nullptr); + } else + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF), + MI.getOperand(0).getReg()); + + BB->remove_instr(&MI); + return BB; +} + MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( MachineInstr &MI, MachineBasicBlock *BB) const { @@ -3264,6 +3299,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( return EmitInitTPIDR2Object(MI, BB); case AArch64::AllocateZABuffer: return EmitAllocateZABuffer(MI, BB); + case AArch64::AllocateSMESaveBuffer: + return EmitAllocateSMESaveBuffer(MI, BB); + case AArch64::GetSMESaveSize: { + // If the buffer is used, emit a call to __arm_sme_state_size() + MachineFunction *MF = BB->getParent(); + AArch64FunctionInfo *FuncInfo = MF->getInfo(); + const TargetInstrInfo *TII = Subtarget->getInstrInfo(); + if (FuncInfo->getSMESaveBufferUsed()) { + const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL)) + .addExternalSymbol("__arm_sme_state_size") + .addReg(AArch64::X0, RegState::ImplicitDefine) + .addRegMask(TRI->getCallPreservedMask( + *MF, CallingConv:: + AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1)); + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), + MI.getOperand(0).getReg()) + .addReg(AArch64::X0); + } else + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), + MI.getOperand(0).getReg()) + .addReg(AArch64::XZR); + BB->remove_instr(&MI); + return BB; + } case AArch64::F128CSEL: return EmitF128CSEL(MI, BB); case TargetOpcode::STATEPOINT: @@ -7663,6 +7723,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC, case CallingConv::AArch64_VectorCall: case CallingConv::AArch64_SVE_VectorCall: case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0: + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1: case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2: return CC_AArch64_AAPCS; case CallingConv::ARM64EC_Thunk_X64: @@ -8122,6 +8183,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments( Chain = DAG.getNode( AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other), {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)}); + } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) { + // Call __arm_sme_state_size(). + SDValue BufferSize = + DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL, + DAG.getVTList(MVT::i64, MVT::Other), Chain); + Chain = BufferSize.getValue(1); + + SDValue Buffer; + if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) { + Buffer = + DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL, + DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize}); + } else { + // Allocate space dynamically. + Buffer = DAG.getNode( + ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other), + {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)}); + MFI.CreateVariableSizedObject(Align(16), nullptr); + } + + // Copy the value to a virtual register, and save that in FuncInfo. + Register BufferPtr = + MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass); + FuncInfo->setSMESaveBufferAddr(BufferPtr); + Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer); } if (CallConv == CallingConv::PreserveNone) { @@ -8410,6 +8496,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal); if (CallerAttrs.requiresSMChange(CalleeAttrs) || CallerAttrs.requiresLazySave(CalleeAttrs) || + CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) || CallerAttrs.hasStreamingBody()) return false; @@ -8734,6 +8821,30 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL, return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops); } +// Emit a call to __arm_sme_save or __arm_sme_restore. +static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI, + SelectionDAG &DAG, + AArch64FunctionInfo *Info, SDLoc DL, + SDValue Chain, bool IsSave) { + TargetLowering::ArgListTy Args; + TargetLowering::ArgListEntry Entry; + Entry.Ty = PointerType::getUnqual(*DAG.getContext()); + Entry.Node = + DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64); + Args.push_back(Entry); + + SDValue Callee = + DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore", + TLI.getPointerTy(DAG.getDataLayout())); + auto *RetTy = Type::getVoidTy(*DAG.getContext()); + TargetLowering::CallLoweringInfo CLI(DAG); + CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy, + Callee, std::move(Args)); + + return TLI.LowerCallTo(CLI).second; +} + static unsigned getSMCondition(const SMEAttrs &CallerAttrs, const SMEAttrs &CalleeAttrs) { if (!CallerAttrs.hasStreamingCompatibleInterface() || @@ -8894,6 +9005,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, }; bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs); + bool RequiresSaveAllZA = + CallerAttrs.requiresPreservingAllZAState(CalleeAttrs); + SDValue ZAStateBuffer; if (RequiresLazySave) { const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); MachinePointerInfo MPI = @@ -8920,6 +9034,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, &MF.getFunction()); return DescribeCallsite(R) << " sets up a lazy save for ZA"; }); + } else if (RequiresSaveAllZA) { + assert(!CalleeAttrs.hasSharedZAInterface() && + "Cannot share state that may not exist"); + Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain, + /*IsSave=*/true); } SDValue PStateSM; @@ -9467,9 +9586,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32), DAG.getConstant(0, DL, MVT::i64)); TPIDR2.Uses++; + } else if (RequiresSaveAllZA) { + Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain, + /*IsSave=*/false); + FuncInfo->setSMESaveBufferUsed(); } - if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) { + if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 || + RequiresSaveAllZA) { for (unsigned I = 0; I < InVals.size(); ++I) { // The smstart/smstop is chained as part of the call, but when the // resulting chain is discarded (which happens when the call is not part @@ -28084,7 +28208,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const { auto CalleeAttrs = SMEAttrs(*Base); if (CallerAttrs.requiresSMChange(CalleeAttrs) || CallerAttrs.requiresLazySave(CalleeAttrs) || - CallerAttrs.requiresPreservingZT0(CalleeAttrs)) + CallerAttrs.requiresPreservingZT0(CalleeAttrs) || + CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) return true; } return false; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 36d62ca69ca08..9b5eaf38a00b1 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -466,6 +466,10 @@ enum NodeType : unsigned { ALLOCATE_ZA_BUFFER, INIT_TPIDR2OBJ, + // Needed for __arm_agnostic("sme_za_state") + GET_SME_SAVE_SIZE, + ALLOC_SME_SAVE_BUFFER, + // Asserts that a function argument (i32) is zero-extended to i8 by // the caller ASSERT_ZEXT_BOOL, @@ -667,6 +671,8 @@ class AArch64TargetLowering : public TargetLowering { MachineBasicBlock *BB) const; MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI, MachineBasicBlock *BB) const; + MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI, + MachineBasicBlock *BB) const; MachineBasicBlock * EmitInstrWithCustomInserter(MachineInstr &MI, diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h index a77fdaf19bcf5..7fd3a6c560329 100644 --- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h +++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h @@ -229,6 +229,14 @@ class AArch64FunctionInfo final : public MachineFunctionInfo { // on function entry to record the initial pstate of a function. Register PStateSMReg = MCRegister::NoRegister; + // Holds a pointer to a buffer that is large enough to represent + // all SME ZA state and any additional state required by the + // __arm_sme_save/restore support routines. + Register SMESaveBufferAddr = MCRegister::NoRegister; + + // true if SMESaveBufferAddr is used. + bool SMESaveBufferUsed = false; + // Has the PNReg used to build PTRUE instruction. // The PTRUE is used for the LD/ST of ZReg pairs in save and restore. unsigned PredicateRegForFillSpill = 0; @@ -252,6 +260,12 @@ class AArch64FunctionInfo final : public MachineFunctionInfo { return PredicateRegForFillSpill; } + Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; }; + void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; }; + + unsigned getSMESaveBufferUsed() const { return SMESaveBufferUsed; }; + void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; }; + Register getPStateSMReg() const { return PStateSMReg; }; void setPStateSMReg(Register Reg) { PStateSMReg = Reg; }; diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index fedf761f53b64..8b8d73d78a1ea 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -52,6 +52,22 @@ let usesCustomInserter = 1 in { def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {} } +// Nodes to allocate a save buffer for SME. +def AArch64SMESaveSize : SDNode<"AArch64ISD::GET_SME_SAVE_SIZE", SDTypeProfile<1, 0, + [SDTCisInt<0>]>, [SDNPHasChain]>; +let usesCustomInserter = 1, Defs = [X0] in { + def GetSMESaveSize : Pseudo<(outs GPR64:$dst), (ins), []>, Sched<[]> {} +} +def : Pat<(i64 AArch64SMESaveSize), (GetSMESaveSize)>; + +def AArch64AllocateSMESaveBuffer : SDNode<"AArch64ISD::ALLOC_SME_SAVE_BUFFER", SDTypeProfile<1, 1, + [SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain]>; +let usesCustomInserter = 1, Defs = [SP] in { + def AllocateSMESaveBuffer : Pseudo<(outs GPR64sp:$dst), (ins GPR64:$size), []>, Sched<[WriteI]> {} +} +def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)), + (AllocateSMESaveBuffer $size)>; + //===----------------------------------------------------------------------===// // Instruction naming conventions. //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 80be7649d0fd7..77fc5cafae93d 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -261,7 +261,13 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, if (CallerAttrs.requiresLazySave(CalleeAttrs) || CallerAttrs.requiresSMChange(CalleeAttrs) || - CallerAttrs.requiresPreservingZT0(CalleeAttrs)) { + CallerAttrs.requiresPreservingZT0(CalleeAttrs) || + CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) { + if (hasPossibleIncompatibleOps(Callee)) + return false; + } + + if (CalleeAttrs.hasAgnosticZAInterface()) { if (hasPossibleIncompatibleOps(Callee)) return false; } diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index 015ca4cb92b25..bf16acd7f8f7e 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -38,6 +38,10 @@ void SMEAttrs::set(unsigned M, bool Enable) { isPreservesZT0())) && "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', " "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive"); + + assert(!(hasAgnosticZAInterface() && hasSharedZAInterface()) && + "Function cannot have a shared-ZA interface and an agnostic-ZA " + "interface"); } SMEAttrs::SMEAttrs(const CallBase &CB) { @@ -56,6 +60,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) { if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" || FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr") Bitmask |= SMEAttrs::SM_Compatible; + if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" || + FuncName == "__arm_sme_state_size") + Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine; } SMEAttrs::SMEAttrs(const AttributeList &Attrs) { @@ -66,6 +73,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { Bitmask |= SM_Compatible; if (Attrs.hasFnAttr("aarch64_pstate_sm_body")) Bitmask |= SM_Body; + if (Attrs.hasFnAttr("aarch64_za_state_agnostic")) + Bitmask |= ZA_State_Agnostic; if (Attrs.hasFnAttr("aarch64_in_za")) Bitmask |= encodeZAState(StateValue::In); if (Attrs.hasFnAttr("aarch64_out_za")) diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index 4c7c1c9b07953..fb093da70c46b 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -42,9 +42,10 @@ class SMEAttrs { SM_Compatible = 1 << 1, // aarch64_pstate_sm_compatible SM_Body = 1 << 2, // aarch64_pstate_sm_body SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves - ZA_Shift = 4, + ZA_State_Agnostic = 1 << 4, + ZA_Shift = 5, ZA_Mask = 0b111 << ZA_Shift, - ZT0_Shift = 7, + ZT0_Shift = 8, ZT0_Mask = 0b111 << ZT0_Shift }; @@ -96,8 +97,11 @@ class SMEAttrs { return State == StateValue::In || State == StateValue::Out || State == StateValue::InOut || State == StateValue::Preserved; } + bool hasAgnosticZAInterface() const { return Bitmask & ZA_State_Agnostic; } bool hasSharedZAInterface() const { return sharesZA() || sharesZT0(); } - bool hasPrivateZAInterface() const { return !hasSharedZAInterface(); } + bool hasPrivateZAInterface() const { + return !hasSharedZAInterface() && !hasAgnosticZAInterface(); + } bool hasZAState() const { return isNewZA() || sharesZA(); } bool requiresLazySave(const SMEAttrs &Callee) const { return hasZAState() && Callee.hasPrivateZAInterface() && @@ -128,7 +132,8 @@ class SMEAttrs { } bool hasZT0State() const { return isNewZT0() || sharesZT0(); } bool requiresPreservingZT0(const SMEAttrs &Callee) const { - return hasZT0State() && !Callee.sharesZT0(); + return hasZT0State() && !Callee.sharesZT0() && + !Callee.hasAgnosticZAInterface(); } bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const { return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() && @@ -137,6 +142,10 @@ class SMEAttrs { bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const { return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee); } + bool requiresPreservingAllZAState(const SMEAttrs &Callee) const { + return hasAgnosticZAInterface() && !Callee.hasAgnosticZAInterface() && + !(Callee.Bitmask & SME_ABI_Routine); + } }; } // namespace llvm diff --git a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll new file mode 100644 index 0000000000000..2e613118acbe0 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll @@ -0,0 +1,84 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc -mattr=+sme2 < %s | FileCheck %s + +target triple = "aarch64" + +declare i64 @private_za_decl(i64) +declare i64 @agnostic_decl(i64) "aarch64_za_state_agnostic" + +; No calls. Test that no buffer is allocated. +define i64 @agnostic_caller_no_callees(ptr %ptr) nounwind "aarch64_za_state_agnostic" { +; CHECK-LABEL: agnostic_caller_no_callees: +; CHECK: // %bb.0: +; CHECK-NEXT: ldr x0, [x0] +; CHECK-NEXT: ret + %v = load i64, ptr %ptr + ret i64 %v +} + +; agnostic-ZA -> private-ZA +; +; Test that a buffer is allocated and that the appropriate save/restore calls are +; inserted for calls to non-agnostic functions and that the arg/result registers are +; preserved by the register allocator. +define i64 @agnostic_caller_private_za_callee(i64 %v) nounwind "aarch64_za_state_agnostic" { +; CHECK-LABEL: agnostic_caller_private_za_callee: +; CHECK: // %bb.0: +; CHECK-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill +; CHECK-NEXT: str x19, [sp, #16] // 8-byte Folded Spill +; CHECK-NEXT: mov x29, sp +; CHECK-NEXT: mov x8, x0 +; CHECK-NEXT: bl __arm_sme_state_size +; CHECK-NEXT: sub x19, sp, x0 +; CHECK-NEXT: mov sp, x19 +; CHECK-NEXT: mov x0, x19 +; CHECK-NEXT: bl __arm_sme_save +; CHECK-NEXT: mov x0, x8 +; CHECK-NEXT: bl private_za_decl +; CHECK-NEXT: mov x1, x0 +; CHECK-NEXT: mov x0, x19 +; CHECK-NEXT: bl __arm_sme_restore +; CHECK-NEXT: mov x0, x19 +; CHECK-NEXT: bl __arm_sme_save +; CHECK-NEXT: mov x0, x1 +; CHECK-NEXT: bl private_za_decl +; CHECK-NEXT: mov x1, x0 +; CHECK-NEXT: mov x0, x19 +; CHECK-NEXT: bl __arm_sme_restore +; CHECK-NEXT: mov x0, x1 +; CHECK-NEXT: mov sp, x29 +; CHECK-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload +; CHECK-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload +; CHECK-NEXT: ret + %res = call i64 @private_za_decl(i64 %v) + %res2 = call i64 @private_za_decl(i64 %res) + ret i64 %res2 +} + +; agnostic-ZA -> agnostic-ZA +; +; Should not result in save/restore code. +define i64 @agnostic_caller_agnostic_callee(i64 %v) nounwind "aarch64_za_state_agnostic" { +; CHECK-LABEL: agnostic_caller_agnostic_callee: +; CHECK: // %bb.0: +; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill +; CHECK-NEXT: bl agnostic_decl +; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-NEXT: ret + %res = call i64 @agnostic_decl(i64 %v) + ret i64 %res +} + +; shared-ZA -> agnostic-ZA +; +; Should not result in lazy-save or save of ZT0 +define i64 @shared_caller_agnostic_callee(i64 %v) nounwind "aarch64_inout_za" "aarch64_inout_zt0" { +; CHECK-LABEL: shared_caller_agnostic_callee: +; CHECK: // %bb.0: +; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill +; CHECK-NEXT: bl agnostic_decl +; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-NEXT: ret + %res = call i64 @agnostic_decl(i64 %v) + ret i64 %res +} diff --git a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll index 42dba22d25708..d9dc2ad841f16 100644 --- a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll +++ b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll @@ -526,3 +526,27 @@ entry: %add = fadd double %call, 4.200000e+01 ret double %add; } + +define void @agnostic_za_function(ptr %ptr) nounwind "aarch64_za_state_agnostic" { +; CHECK-COMMON-LABEL: agnostic_za_function: +; CHECK-COMMON: // %bb.0: +; CHECK-COMMON-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill +; CHECK-COMMON-NEXT: stp x20, x19, [sp, #16] // 16-byte Folded Spill +; CHECK-COMMON-NEXT: mov x29, sp +; CHECK-COMMON-NEXT: mov x8, x0 +; CHECK-COMMON-NEXT: bl __arm_sme_state_size +; CHECK-COMMON-NEXT: sub x20, sp, x0 +; CHECK-COMMON-NEXT: mov sp, x20 +; CHECK-COMMON-NEXT: mov x0, x20 +; CHECK-COMMON-NEXT: bl __arm_sme_save +; CHECK-COMMON-NEXT: blr x8 +; CHECK-COMMON-NEXT: mov x0, x20 +; CHECK-COMMON-NEXT: bl __arm_sme_restore +; CHECK-COMMON-NEXT: mov sp, x29 +; CHECK-COMMON-NEXT: ldp x20, x19, [sp, #16] // 16-byte Folded Reload +; CHECK-COMMON-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload +; CHECK-COMMON-NEXT: ret + call void %ptr() + ret void +} + diff --git a/llvm/test/Verifier/sme-attributes.ll b/llvm/test/Verifier/sme-attributes.ll index 3d01613ebf2fe..4bf5e813daf2f 100644 --- a/llvm/test/Verifier/sme-attributes.ll +++ b/llvm/test/Verifier/sme-attributes.ll @@ -4,61 +4,67 @@ declare void @sm_attrs() "aarch64_pstate_sm_enabled" "aarch64_pstate_sm_compatib ; CHECK: Attributes 'aarch64_pstate_sm_enabled and aarch64_pstate_sm_compatible' are incompatible! declare void @za_new_preserved() "aarch64_new_za" "aarch64_preserves_za"; -; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive +; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @za_new_in() "aarch64_new_za" "aarch64_in_za"; -; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive +; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @za_new_inout() "aarch64_new_za" "aarch64_inout_za"; -; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive +; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @za_new_out() "aarch64_new_za" "aarch64_out_za"; -; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive +; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @za_preserved_in() "aarch64_preserves_za" "aarch64_in_za"; -; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive +; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @za_preserved_inout() "aarch64_preserves_za" "aarch64_inout_za"; -; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive +; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @za_preserved_out() "aarch64_preserves_za" "aarch64_out_za"; -; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive +; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @za_in_inout() "aarch64_in_za" "aarch64_inout_za"; -; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive +; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @za_in_out() "aarch64_in_za" "aarch64_out_za"; -; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive +; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @za_inout_out() "aarch64_inout_za" "aarch64_out_za"; -; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive +; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive + +declare void @za_inout_agnostic() "aarch64_inout_za" "aarch64_za_state_agnostic"; +; CHECK: Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', 'aarch64_inout_za', 'aarch64_preserves_za' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @zt0_new_preserved() "aarch64_new_zt0" "aarch64_preserves_zt0"; -; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive +; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @zt0_new_in() "aarch64_new_zt0" "aarch64_in_zt0"; -; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive +; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @zt0_new_inout() "aarch64_new_zt0" "aarch64_inout_zt0"; -; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive +; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @zt0_new_out() "aarch64_new_zt0" "aarch64_out_zt0"; -; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive +; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @zt0_preserved_in() "aarch64_preserves_zt0" "aarch64_in_zt0"; -; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive +; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @zt0_preserved_inout() "aarch64_preserves_zt0" "aarch64_inout_zt0"; -; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive +; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @zt0_preserved_out() "aarch64_preserves_zt0" "aarch64_out_zt0"; -; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive +; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @zt0_in_inout() "aarch64_in_zt0" "aarch64_inout_zt0"; -; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive +; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @zt0_in_out() "aarch64_in_zt0" "aarch64_out_zt0"; -; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive +; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive declare void @zt0_inout_out() "aarch64_inout_zt0" "aarch64_out_zt0"; -; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive +; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive + +declare void @zt0_inout_agnostic() "aarch64_inout_zt0" "aarch64_za_state_agnostic"; +; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive From b5c04bbf4100673d9b848d10f66e8e210c4100a3 Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Mon, 23 Dec 2024 09:59:59 +0000 Subject: [PATCH 2/5] Address review comments --- .../Target/AArch64/AArch64ISelLowering.cpp | 67 ++++++++++--------- llvm/lib/Target/AArch64/AArch64ISelLowering.h | 2 + llvm/test/CodeGen/AArch64/sme-agnostic-za.ll | 4 +- .../AArch64/sme-disable-gisel-fisel.ll | 4 +- 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a9ff4860da580..7941bd4ce4b59 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -3244,16 +3244,15 @@ AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI, const TargetInstrInfo *TII = Subtarget->getInstrInfo(); if (FuncInfo->getSMESaveBufferUsed()) { - // Allocate a lazy-save buffer object of the size given, normally SVL * SVL + // Allocate a buffer object of the size given by MI.getOperand(1). auto Size = MI.getOperand(1).getReg(); auto Dest = MI.getOperand(0).getReg(); - BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest) + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), AArch64::SP) .addReg(AArch64::SP) .addReg(Size) .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0)); - BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), - AArch64::SP) - .addReg(Dest); + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), Dest) + .addReg(AArch64::SP); // We have just allocated a variable sized object, tell this to PEI. MFI.CreateVariableSizedObject(Align(16), nullptr); @@ -3265,6 +3264,32 @@ AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI, return BB; } +MachineBasicBlock * +AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI, + MachineBasicBlock *BB) const { + // If the buffer is used, emit a call to __arm_sme_state_size() + MachineFunction *MF = BB->getParent(); + AArch64FunctionInfo *FuncInfo = MF->getInfo(); + const TargetInstrInfo *TII = Subtarget->getInstrInfo(); + if (FuncInfo->getSMESaveBufferUsed()) { + const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL)) + .addExternalSymbol("__arm_sme_state_size") + .addReg(AArch64::X0, RegState::ImplicitDefine) + .addRegMask(TRI->getCallPreservedMask( + *MF, CallingConv:: + AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1)); + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), + MI.getOperand(0).getReg()) + .addReg(AArch64::X0); + } else + BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), + MI.getOperand(0).getReg()) + .addReg(AArch64::XZR); + BB->remove_instr(&MI); + return BB; +} + MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( MachineInstr &MI, MachineBasicBlock *BB) const { @@ -3301,29 +3326,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( return EmitAllocateZABuffer(MI, BB); case AArch64::AllocateSMESaveBuffer: return EmitAllocateSMESaveBuffer(MI, BB); - case AArch64::GetSMESaveSize: { - // If the buffer is used, emit a call to __arm_sme_state_size() - MachineFunction *MF = BB->getParent(); - AArch64FunctionInfo *FuncInfo = MF->getInfo(); - const TargetInstrInfo *TII = Subtarget->getInstrInfo(); - if (FuncInfo->getSMESaveBufferUsed()) { - const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); - BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL)) - .addExternalSymbol("__arm_sme_state_size") - .addReg(AArch64::X0, RegState::ImplicitDefine) - .addRegMask(TRI->getCallPreservedMask( - *MF, CallingConv:: - AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1)); - BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), - MI.getOperand(0).getReg()) - .addReg(AArch64::X0); - } else - BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), - MI.getOperand(0).getReg()) - .addReg(AArch64::XZR); - BB->remove_instr(&MI); - return BB; - } + case AArch64::GetSMESaveSize: + return EmitGetSMESaveSize(MI, BB); case AArch64::F128CSEL: return EmitF128CSEL(MI, BB); case TargetOpcode::STATEPOINT: @@ -8826,6 +8830,10 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI, SelectionDAG &DAG, AArch64FunctionInfo *Info, SDLoc DL, SDValue Chain, bool IsSave) { + MachineFunction &MF = DAG.getMachineFunction(); + AArch64FunctionInfo *FuncInfo = MF.getInfo(); + FuncInfo->setSMESaveBufferUsed(); + TargetLowering::ArgListTy Args; TargetLowering::ArgListEntry Entry; Entry.Ty = PointerType::getUnqual(*DAG.getContext()); @@ -8841,7 +8849,6 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI, CLI.setDebugLoc(DL).setChain(Chain).setLibCallee( CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy, Callee, std::move(Args)); - return TLI.LowerCallTo(CLI).second; } @@ -9007,7 +9014,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs); bool RequiresSaveAllZA = CallerAttrs.requiresPreservingAllZAState(CalleeAttrs); - SDValue ZAStateBuffer; if (RequiresLazySave) { const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj(); MachinePointerInfo MPI = @@ -9589,7 +9595,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, } else if (RequiresSaveAllZA) { Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain, /*IsSave=*/false); - FuncInfo->setSMESaveBufferUsed(); } if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 || diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 9b5eaf38a00b1..1b7f328fa729a 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -673,6 +673,8 @@ class AArch64TargetLowering : public TargetLowering { MachineBasicBlock *BB) const; MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI, MachineBasicBlock *BB) const; + MachineBasicBlock *EmitGetSMESaveSize(MachineInstr &MI, + MachineBasicBlock *BB) const; MachineBasicBlock * EmitInstrWithCustomInserter(MachineInstr &MI, diff --git a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll index 2e613118acbe0..97522b9a319c0 100644 --- a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll +++ b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll @@ -29,8 +29,8 @@ define i64 @agnostic_caller_private_za_callee(i64 %v) nounwind "aarch64_za_state ; CHECK-NEXT: mov x29, sp ; CHECK-NEXT: mov x8, x0 ; CHECK-NEXT: bl __arm_sme_state_size -; CHECK-NEXT: sub x19, sp, x0 -; CHECK-NEXT: mov sp, x19 +; CHECK-NEXT: sub sp, sp, x0 +; CHECK-NEXT: mov x19, sp ; CHECK-NEXT: mov x0, x19 ; CHECK-NEXT: bl __arm_sme_save ; CHECK-NEXT: mov x0, x8 diff --git a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll index d9dc2ad841f16..fc0208d605dd7 100644 --- a/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll +++ b/llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll @@ -535,8 +535,8 @@ define void @agnostic_za_function(ptr %ptr) nounwind "aarch64_za_state_agnostic" ; CHECK-COMMON-NEXT: mov x29, sp ; CHECK-COMMON-NEXT: mov x8, x0 ; CHECK-COMMON-NEXT: bl __arm_sme_state_size -; CHECK-COMMON-NEXT: sub x20, sp, x0 -; CHECK-COMMON-NEXT: mov sp, x20 +; CHECK-COMMON-NEXT: sub sp, sp, x0 +; CHECK-COMMON-NEXT: mov x20, sp ; CHECK-COMMON-NEXT: mov x0, x20 ; CHECK-COMMON-NEXT: bl __arm_sme_save ; CHECK-COMMON-NEXT: blr x8 From cb3e1e2990d6245f4e136c9cf4132568f57900b6 Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Mon, 23 Dec 2024 13:40:45 +0000 Subject: [PATCH 3/5] Fix inlining --- .../AArch64/AArch64TargetTransformInfo.cpp | 5 - .../Inline/AArch64/sme-pstateza-attrs.ll | 145 +++++++++++++++++- 2 files changed, 142 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 77fc5cafae93d..0566a87590012 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -267,11 +267,6 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, return false; } - if (CalleeAttrs.hasAgnosticZAInterface()) { - if (hasPossibleIncompatibleOps(Callee)) - return false; - } - return BaseT::areInlineCompatible(Caller, Callee); } diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll index aaab5115261e5..f783c7e582552 100644 --- a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll +++ b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll @@ -50,16 +50,30 @@ define void @new_za_callee() "aarch64_new_za" { ret void } +define void @agnostic_za_callee() "aarch64_za_state_agnostic" { +; CHECK-LABEL: define void @agnostic_za_callee +; CHECK-SAME: () #[[ATTR3:[0-9]+]] { +; CHECK-NEXT: call void asm sideeffect " +; CHECK-NEXT: call void @inlined_body() +; CHECK-NEXT: ret void +; + call void asm sideeffect "; inlineasm", ""() + call void @inlined_body() + ret void +} + ; ; Now test that inlining only happens when no lazy-save is needed. ; Test for a number of combinations, where: ; N Not using ZA. ; S Shared ZA interface ; Z New ZA with Private-ZA interface +; A Agnostic ZA interface ; [x] N -> N ; [ ] N -> S (This combination is invalid) ; [ ] N -> Z +; [ ] N -> A define void @nonza_caller_nonza_callee_inline() { ; CHECK-LABEL: define void @nonza_caller_nonza_callee_inline ; CHECK-SAME: () #[[ATTR0]] { @@ -76,6 +90,7 @@ entry: ; [ ] N -> N ; [ ] N -> S (This combination is invalid) ; [x] N -> Z +; [ ] N -> A define void @nonza_caller_new_za_callee_dont_inline() { ; CHECK-LABEL: define void @nonza_caller_new_za_callee_dont_inline ; CHECK-SAME: () #[[ATTR0]] { @@ -88,9 +103,27 @@ entry: ret void } +; [ ] N -> N +; [ ] N -> S (This combination is invalid) +; [ ] N -> Z +; [x] N -> A +define void @nonza_caller_agnostic_za_callee_inline() { +; CHECK-LABEL: define void @nonza_caller_agnostic_za_callee_inline +; CHECK-SAME: () #[[ATTR0]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: call void asm sideeffect " +; CHECK-NEXT: call void @inlined_body() +; CHECK-NEXT: ret void +; +entry: + call void @agnostic_za_callee() + ret void +} + ; [x] Z -> N ; [ ] Z -> S ; [ ] Z -> Z +; [ ] Z -> A define void @new_za_caller_nonza_callee_dont_inline() "aarch64_new_za" { ; CHECK-LABEL: define void @new_za_caller_nonza_callee_dont_inline ; CHECK-SAME: () #[[ATTR2]] { @@ -106,6 +139,7 @@ entry: ; [ ] Z -> N ; [x] Z -> S ; [ ] Z -> Z +; [ ] Z -> A define void @new_za_caller_shared_za_callee_inline() "aarch64_new_za" { ; CHECK-LABEL: define void @new_za_caller_shared_za_callee_inline ; CHECK-SAME: () #[[ATTR2]] { @@ -122,6 +156,7 @@ entry: ; [ ] Z -> N ; [ ] Z -> S ; [x] Z -> Z +; [ ] Z -> A define void @new_za_caller_new_za_callee_dont_inline() "aarch64_new_za" { ; CHECK-LABEL: define void @new_za_caller_new_za_callee_dont_inline ; CHECK-SAME: () #[[ATTR2]] { @@ -134,9 +169,27 @@ entry: ret void } +; [ ] Z -> N +; [ ] Z -> S +; [ ] Z -> Z +; [x] Z -> A +define void @new_za_caller_agnostic_za_callee_inline() "aarch64_new_za" { +; CHECK-LABEL: define void @new_za_caller_agnostic_za_callee_inline +; CHECK-SAME: () #[[ATTR2]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: call void asm sideeffect " +; CHECK-NEXT: call void @inlined_body() +; CHECK-NEXT: ret void +; +entry: + call void @agnostic_za_callee() + ret void +} + ; [x] S -> N ; [ ] S -> S ; [ ] S -> Z +; [ ] S -> A define void @shared_za_caller_nonza_callee_dont_inline() "aarch64_inout_za" { ; CHECK-LABEL: define void @shared_za_caller_nonza_callee_dont_inline ; CHECK-SAME: () #[[ATTR1]] { @@ -152,6 +205,7 @@ entry: ; [ ] S -> N ; [x] S -> Z ; [ ] S -> S +; [ ] S -> A define void @shared_za_caller_new_za_callee_dont_inline() "aarch64_inout_za" { ; CHECK-LABEL: define void @shared_za_caller_new_za_callee_dont_inline ; CHECK-SAME: () #[[ATTR1]] { @@ -167,6 +221,7 @@ entry: ; [ ] S -> N ; [ ] S -> Z ; [x] S -> S +; [ ] S -> A define void @shared_za_caller_shared_za_callee_inline() "aarch64_inout_za" { ; CHECK-LABEL: define void @shared_za_caller_shared_za_callee_inline ; CHECK-SAME: () #[[ATTR1]] { @@ -180,6 +235,90 @@ entry: ret void } +; [ ] S -> N +; [ ] S -> Z +; [ ] S -> S +; [x] S -> A +define void @shared_za_caller_agnostic_za_callee_inline() "aarch64_inout_za" { +; CHECK-LABEL: define void @shared_za_caller_agnostic_za_callee_inline +; CHECK-SAME: () #[[ATTR1]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: call void asm sideeffect " +; CHECK-NEXT: call void @inlined_body() +; CHECK-NEXT: ret void +; +entry: + call void @agnostic_za_callee() + ret void +} + +; [x] A -> N +; [ ] A -> Z +; [ ] A -> S +; [ ] A -> A +define void @agnostic_za_caller_nonza_callee_dont_inline() "aarch64_za_state_agnostic" { +; CHECK-LABEL: define void @agnostic_za_caller_nonza_callee_dont_inline +; CHECK-SAME: () #[[ATTR3]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: call void @nonza_callee() +; CHECK-NEXT: ret void +; +entry: + call void @nonza_callee() + ret void +} + +; [ ] A -> N +; [x] A -> Z +; [ ] A -> S +; [ ] A -> A +define void @agnostic_za_caller_now_za_callee_dont_inline() "aarch64_za_state_agnostic" { +; CHECK-LABEL: define void @agnostic_za_caller_now_za_callee_dont_inline +; CHECK-SAME: () #[[ATTR3]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: call void @new_za_callee() +; CHECK-NEXT: ret void +; +entry: + call void @new_za_callee() + ret void +} + +; [ ] A -> N +; [ ] A -> Z +; [x] A -> S (invalid) +; [ ] A -> A +define void @agnostic_za_caller_shared_za_callee_dont_inline() "aarch64_za_state_agnostic" { +; CHECK-LABEL: define void @agnostic_za_caller_shared_za_callee_dont_inline +; CHECK-SAME: () #[[ATTR3]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: call void @shared_za_callee() +; CHECK-NEXT: ret void +; +entry: + call void @shared_za_callee() + ret void +} + +; [ ] A -> N +; [ ] A -> Z +; [ ] A -> S +; [x] A -> A +define void @agnostic_za_caller_agnostic_za_callee_inline() "aarch64_za_state_agnostic" { +; CHECK-LABEL: define void @agnostic_za_caller_agnostic_za_callee_inline +; CHECK-SAME: () #[[ATTR3]] { +; CHECK-NEXT: entry: +; CHECK-NEXT: call void asm sideeffect " +; CHECK-NEXT: call void @inlined_body() +; CHECK-NEXT: ret void +; +entry: + call void @agnostic_za_callee() + ret void +} + + + define void @private_za_callee_call_za_disable() { ; CHECK-LABEL: define void @private_za_callee_call_za_disable ; CHECK-SAME: () #[[ATTR0]] { @@ -254,7 +393,7 @@ define void @nonzt0_callee() { define void @shared_zt0_caller_nonzt0_callee_dont_inline() "aarch64_inout_zt0" { ; CHECK-LABEL: define void @shared_zt0_caller_nonzt0_callee_dont_inline -; CHECK-SAME: () #[[ATTR3:[0-9]+]] { +; CHECK-SAME: () #[[ATTR4:[0-9]+]] { ; CHECK-NEXT: call void @nonzt0_callee() ; CHECK-NEXT: ret void ; @@ -264,7 +403,7 @@ define void @shared_zt0_caller_nonzt0_callee_dont_inline() "aarch64_inout_zt0" { define void @shared_zt0_callee() "aarch64_inout_zt0" { ; CHECK-LABEL: define void @shared_zt0_callee -; CHECK-SAME: () #[[ATTR3]] { +; CHECK-SAME: () #[[ATTR4]] { ; CHECK-NEXT: call void asm sideeffect " ; CHECK-NEXT: call void @inlined_body() ; CHECK-NEXT: ret void @@ -276,7 +415,7 @@ define void @shared_zt0_callee() "aarch64_inout_zt0" { define void @shared_zt0_caller_shared_zt0_callee_inline() "aarch64_inout_zt0" { ; CHECK-LABEL: define void @shared_zt0_caller_shared_zt0_callee_inline -; CHECK-SAME: () #[[ATTR3]] { +; CHECK-SAME: () #[[ATTR4]] { ; CHECK-NEXT: call void asm sideeffect " ; CHECK-NEXT: call void @inlined_body() ; CHECK-NEXT: ret void From dcf6f5217b392a1c251f1b0f0125f1eb0ba60c10 Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Mon, 23 Dec 2024 14:02:16 +0000 Subject: [PATCH 4/5] Rename getSMESaveBufferUsed -> isSMESaveBufferUsed --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 ++-- llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 7941bd4ce4b59..a6f8f47f31fa5 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -3243,7 +3243,7 @@ AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI, "Lazy ZA save is not yet supported on Windows"); const TargetInstrInfo *TII = Subtarget->getInstrInfo(); - if (FuncInfo->getSMESaveBufferUsed()) { + if (FuncInfo->isSMESaveBufferUsed()) { // Allocate a buffer object of the size given by MI.getOperand(1). auto Size = MI.getOperand(1).getReg(); auto Dest = MI.getOperand(0).getReg(); @@ -3271,7 +3271,7 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI, MachineFunction *MF = BB->getParent(); AArch64FunctionInfo *FuncInfo = MF->getInfo(); const TargetInstrInfo *TII = Subtarget->getInstrInfo(); - if (FuncInfo->getSMESaveBufferUsed()) { + if (FuncInfo->isSMESaveBufferUsed()) { const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL)) .addExternalSymbol("__arm_sme_state_size") diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h index 7fd3a6c560329..427d86ee1bb8e 100644 --- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h +++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h @@ -263,7 +263,7 @@ class AArch64FunctionInfo final : public MachineFunctionInfo { Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; }; void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; }; - unsigned getSMESaveBufferUsed() const { return SMESaveBufferUsed; }; + unsigned isSMESaveBufferUsed() const { return SMESaveBufferUsed; }; void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; }; Register getPStateSMReg() const { return PStateSMReg; }; From 2fd87f7ecf8fe30b17a577a002f54d779bdbb8a2 Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Mon, 23 Dec 2024 15:34:40 +0000 Subject: [PATCH 5/5] s/now/new/ --- llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll index f783c7e582552..7ffbd64c700aa 100644 --- a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll +++ b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll @@ -272,8 +272,8 @@ entry: ; [x] A -> Z ; [ ] A -> S ; [ ] A -> A -define void @agnostic_za_caller_now_za_callee_dont_inline() "aarch64_za_state_agnostic" { -; CHECK-LABEL: define void @agnostic_za_caller_now_za_callee_dont_inline +define void @agnostic_za_caller_new_za_callee_dont_inline() "aarch64_za_state_agnostic" { +; CHECK-LABEL: define void @agnostic_za_caller_new_za_callee_dont_inline ; CHECK-SAME: () #[[ATTR3]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: call void @new_za_callee()