 #include "AArch64PerfectShuffle.h"
 #include "AArch64RegisterInfo.h"
 #include "AArch64Subtarget.h"
+#include "AArch64TargetMachine.h"
 #include "MCTargetDesc/AArch64AddressingModes.h"
 #include "Utils/AArch64BaseInfo.h"
 #include "Utils/AArch64SMEAttributes.h"
@@ -1998,6 +1999,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationAction(Op, MVT::f16, Promote);
 }
 
+const AArch64TargetMachine &AArch64TargetLowering::getTM() const {
+  return static_cast<const AArch64TargetMachine &>(getTargetMachine());
+}
+
 void AArch64TargetLowering::addTypeForNEON(MVT VT) {
   assert(VT.isVector() && "VT should be a vector type");
 
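The new getTM() accessor simply downcasts getTargetMachine() to the AArch64 subclass so the lowering code can query AArch64-specific target-machine options (here, useNewSMEABILowering()). A minimal sketch of the same pattern, using hypothetical class names rather than the real LLVM types:

```cpp
#include <cassert>

// Hypothetical stand-ins for TargetMachine and AArch64TargetMachine.
struct TargetMachineBase {
  virtual ~TargetMachineBase() = default;
};

struct MyTargetMachine final : TargetMachineBase {
  bool NewSMEABILowering = false;
  bool useNewSMEABILowering() const { return NewSMEABILowering; }
};

class MyTargetLowering {
  const TargetMachineBase &TM;

public:
  explicit MyTargetLowering(const TargetMachineBase &TM) : TM(TM) {}
  const TargetMachineBase &getTargetMachine() const { return TM; }

  // Mirrors AArch64TargetLowering::getTM(): the lowering object is only ever
  // constructed with its own target's TargetMachine, so the downcast is safe
  // by construction.
  const MyTargetMachine &getTM() const {
    return static_cast<const MyTargetMachine &>(getTargetMachine());
  }
};

int main() {
  MyTargetMachine TM;
  TM.NewSMEABILowering = true;
  MyTargetLowering TLI(TM);
  assert(TLI.getTM().useNewSMEABILowering());
}
```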
@@ -8285,53 +8290,54 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
   if (Subtarget->hasCustomCallingConv())
     Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
 
-  // Create a 16 Byte TPIDR2 object. The dynamic buffer
-  // will be expanded and stored in the static object later using a pseudonode.
-  if (Attrs.hasZAState()) {
-    TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
-    TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
-    SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
-                              DAG.getConstant(1, DL, MVT::i32));
-
-    SDValue Buffer;
-    if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
-      Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
-                           DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
-    } else {
-      SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
-      Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
-                           DAG.getVTList(MVT::i64, MVT::Other),
-                           {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
-      MFI.CreateVariableSizedObject(Align(16), nullptr);
-    }
-    Chain = DAG.getNode(
-        AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
-        {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
-  } else if (Attrs.hasAgnosticZAInterface()) {
-    // Call __arm_sme_state_size().
-    SDValue BufferSize =
-        DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
-                    DAG.getVTList(MVT::i64, MVT::Other), Chain);
-    Chain = BufferSize.getValue(1);
-
-    SDValue Buffer;
-    if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
-      Buffer =
-          DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
-                      DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
-    } else {
-      // Allocate space dynamically.
-      Buffer = DAG.getNode(
-          ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
-          {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
-      MFI.CreateVariableSizedObject(Align(16), nullptr);
+  if (!getTM().useNewSMEABILowering() || Attrs.hasAgnosticZAInterface()) {
+    // Old SME ABI lowering (deprecated):
+    // Create a 16 Byte TPIDR2 object. The dynamic buffer
+    // will be expanded and stored in the static object later using a
+    // pseudonode.
+    if (Attrs.hasZAState()) {
+      TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
+      TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
+      SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                                DAG.getConstant(1, DL, MVT::i32));
+      SDValue Buffer;
+      if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+        Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
+                             DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
+      } else {
+        SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
+        Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
+                             DAG.getVTList(MVT::i64, MVT::Other),
+                             {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
+        MFI.CreateVariableSizedObject(Align(16), nullptr);
+      }
+      Chain = DAG.getNode(
+          AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
+          {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+    } else if (Attrs.hasAgnosticZAInterface()) {
+      // Call __arm_sme_state_size().
+      SDValue BufferSize =
+          DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
+                      DAG.getVTList(MVT::i64, MVT::Other), Chain);
+      Chain = BufferSize.getValue(1);
+      SDValue Buffer;
+      if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+        Buffer = DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
+                             DAG.getVTList(MVT::i64, MVT::Other),
+                             {Chain, BufferSize});
+      } else {
+        // Allocate space dynamically.
+        Buffer = DAG.getNode(
+            ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
+            {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
+        MFI.CreateVariableSizedObject(Align(16), nullptr);
+      }
+      // Copy the value to a virtual register, and save that in FuncInfo.
+      Register BufferPtr =
+          MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+      FuncInfo->setSMESaveBufferAddr(BufferPtr);
+      Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
     }
-
-    // Copy the value to a virtual register, and save that in FuncInfo.
-    Register BufferPtr =
-        MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
-    FuncInfo->setSMESaveBufferAddr(BufferPtr);
-    Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
   }
 
   if (CallConv == CallingConv::PreserveNone) {
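With this change, the ISel-time buffer setup in LowerFormalArguments only runs when the new SME ABI lowering is disabled, or for agnostic-ZA functions (which, per the TODOs later in this patch, the new lowering does not handle yet). The ZA-state path sizes the dynamically allocated lazy-save buffer as SVL x SVL bytes (RDSVL #1 yields the streaming vector length in bytes, and the ZA array is SVL rows of SVL bytes), while the agnostic-ZA path asks the runtime via __arm_sme_state_size(). A tiny sketch of the ZA-state size computation, with the SVL passed in as a plain integer:

```cpp
#include <cstdint>

// The ZA lazy-save buffer covers the whole ZA array: SVL rows of SVL bytes,
// where SVL is the streaming vector length in bytes (what RDSVL #1 returns).
constexpr uint64_t lazySaveBufferBytes(uint64_t SVLBytes) {
  return SVLBytes * SVLBytes;
}

static_assert(lazySaveBufferBytes(64) == 4096,
              "a 512-bit streaming VL needs a 4 KiB ZA buffer");
```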
@@ -8348,6 +8354,15 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     }
   }
 
+  if (getTM().useNewSMEABILowering()) {
+    // Clear new ZT0 state. TODO: Move this to the SME ABI pass.
+    if (Attrs.isNewZT0())
+      Chain = DAG.getNode(
+          ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
+          DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
+          DAG.getTargetConstant(0, DL, MVT::i32));
+  }
+
   return Chain;
 }
 
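The zeroing above corresponds to ACLE's "new ZT0" state: a function with fresh ZT0 state starts with no live ZT0 contents, so ZT0 is cleared on entry via the aarch64_sme_zero_zt intrinsic. A hedged source-level example of code that would take this path (assumes an SME2-enabled compiler, e.g. clang with -march=armv9-a+sme2; the function name is made up):

```cpp
// The __arm_new("zt0") keyword attribute (ACLE SME2) gives the function fresh
// ZT0 state; the compiler zeroes ZT0 in the prologue, which under the new
// lowering is the aarch64_sme_zero_zt INTRINSIC_VOID node built above.
__arm_new("zt0") void starts_with_zeroed_zt0(void) {
  // ... SME2 code that reads/writes ZT0 via <arm_sme.h> intrinsics ...
}
```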
@@ -8919,7 +8934,6 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
   MachineFunction &MF = DAG.getMachineFunction();
   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
   FuncInfo->setSMESaveBufferUsed();
-
   TargetLowering::ArgListTy Args;
   Args.emplace_back(
       DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
@@ -9060,14 +9074,28 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     CallConv = CallingConv::AArch64_SVE_VectorCall;
   }
 
+  // Determine whether we need any streaming mode changes.
+  SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
+  bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
+  bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
+  auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
+    // TODO: Handle agnostic ZA functions.
+    if (!UseNewSMEABILowering || IsAgnosticZAFunction)
+      return std::nullopt;
+    if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
+      return std::nullopt;
+    return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
+                                        : AArch64ISD::INOUT_ZA_USE;
+  }();
+
   if (IsTailCall) {
     // Check if it's really possible to do a tail call.
     IsTailCall = isEligibleForTailCallOptimization(CLI);
 
     // A sibling call is one where we're under the usual C ABI and not planning
     // to change that but can still do a tail call:
-    if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
-        CallConv != CallingConv::SwiftTail)
+    if (!ZAMarkerNode && !TailCallOpt && IsTailCall &&
+        CallConv != CallingConv::Tail && CallConv != CallingConv::SwiftTail)
       IsSibCall = true;
 
     if (IsTailCall)
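Under the new lowering, LowerCall no longer expands the lazy save itself: it only decides which marker node the call sequence needs, and the SME ABI pass referenced in the TODOs expands it later. The immediately-invoked lambda returns REQUIRES_ZA_SAVE when the call would otherwise need a lazy save, INOUT_ZA_USE when the caller's ZA/ZT0 state is merely live across the call, and nothing for callers without ZA/ZT0 state or for agnostic-ZA callers (still handled by the old path). A standalone mirror of that decision logic, with plain bools standing in for the SMECallAttrs queries:

```cpp
#include <optional>

enum class ZAMarker { RequiresZASave, InOutZAUse };

// Hypothetical mirror of the ZAMarkerNode lambda above; not LLVM API, just
// the decision logic.
std::optional<ZAMarker> pickZAMarker(bool UseNewSMEABILowering,
                                     bool CallerIsAgnosticZA,
                                     bool CallerHasZAState,
                                     bool CallerHasZT0State,
                                     bool CallNeedsLazySave) {
  // TODO in the patch: agnostic-ZA callers still go through the old lowering.
  if (!UseNewSMEABILowering || CallerIsAgnosticZA)
    return std::nullopt;
  // Only callers that own ZA or ZT0 state need their calls marked.
  if (!CallerHasZAState && !CallerHasZT0State)
    return std::nullopt;
  return CallNeedsLazySave ? ZAMarker::RequiresZASave : ZAMarker::InOutZAUse;
}
```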
@@ -9119,9 +9147,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
   }
 
-  // Determine whether we need any streaming mode changes.
-  SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
-
   auto DescribeCallsite =
       [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
     R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
@@ -9135,7 +9160,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     return R;
   };
 
-  bool RequiresLazySave = CallAttrs.requiresLazySave();
+  bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
   bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
@@ -9210,10 +9235,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
 
-  // Adjust the stack pointer for the new arguments...
+  // Adjust the stack pointer for the new arguments... and mark ZA uses.
   // These operations are automatically eliminated by the prolog/epilog pass
-  if (!IsSibCall)
+  assert((!IsSibCall || !ZAMarkerNode) && "ZA markers require CALLSEQ_START");
+  if (!IsSibCall) {
     Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
+    if (ZAMarkerNode) {
+      // Note: We need the CALLSEQ_START to glue the ZAMarkerNode to; simply
+      // using a chain can result in incorrect scheduling. The markers refer
+      // to the position just before the CALLSEQ_START (though they occur
+      // after it, as CALLSEQ_START lacks in-glue).
+      Chain = DAG.getNode(*ZAMarkerNode, DL, DAG.getVTList(MVT::Other),
+                          {Chain, Chain.getValue(1)});
+    }
+  }
 
   SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
                                         getPointerTy(DAG.getDataLayout()));
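The marker consumes both results of CALLSEQ_START: its chain and its glue output (Chain.getValue(1)). Gluing, rather than only chaining, pins the marker to the start of the call sequence so scheduling cannot drift it away, and it is also why the earlier sibling-call check now bails out when a marker is needed (sibling calls emit no CALLSEQ_START to glue to). A small sketch of those two guards as plain predicates, with hypothetical parameter names:

```cpp
#include <cassert>
#include <optional>

// Not LLVM API -- just the decision logic from the two guards above.

// A call that needs a ZA marker cannot be lowered as a sibling call, because
// sibling calls omit the CALLSEQ_START the marker must be glued to.
bool allowSiblingCall(bool HasZAMarker, bool TailCallOpt, bool IsTailCall,
                      bool IsTailOrSwiftTailCC) {
  return !HasZAMarker && !TailCallOpt && IsTailCall && !IsTailOrSwiftTailCC;
}

// The assert added above: once the stack adjustment is emitted, a marker may
// only be present if a CALLSEQ_START is actually being built.
void checkZAMarkerInvariant(bool IsSibCall, std::optional<unsigned> ZAMarker) {
  assert((!IsSibCall || !ZAMarker) && "ZA markers require CALLSEQ_START");
  (void)IsSibCall;
  (void)ZAMarker;
}
```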
@@ -9684,7 +9719,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     }
   }
 
-  if (CallAttrs.requiresEnablingZAAfterCall())
+  if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
     // Unconditionally resume ZA.
     Result = DAG.getNode(
         AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
@@ -9706,7 +9741,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     SDValue TPIDR2_EL0 = DAG.getNode(
         ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
         DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
-
     // Copy the address of the TPIDR2 block into X0 before 'calling' the
     // RESTORE_ZA pseudo.
     SDValue Glue;
@@ -9718,7 +9752,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
                     {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
                      RestoreRoutine, RegMask, Result.getValue(1)});
-
     // Finally reset the TPIDR2_EL0 register to 0.
     Result = DAG.getNode(
         ISD::INTRINSIC_VOID, DL, MVT::Other, Result,