@@ -3101,6 +3101,31 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
31013101 return BB;
31023102}
31033103
3104+ MachineBasicBlock *
3105+ AArch64TargetLowering::EmitEntryPStateSM(MachineInstr &MI,
3106+ MachineBasicBlock *BB) const {
3107+ MachineFunction *MF = BB->getParent();
3108+ AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
3109+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3110+ Register ResultReg = MI.getOperand(0).getReg();
3111+ if (FuncInfo->isPStateSMRegUsed()) {
3112+ const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
3113+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
3114+ .addExternalSymbol("__arm_sme_state")
3115+ .addReg(AArch64::X0, RegState::ImplicitDefine)
3116+ .addRegMask(TRI->getCallPreservedMask(
3117+ *MF, CallingConv::
3118+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2));
3119+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), ResultReg)
3120+ .addReg(AArch64::X0);
3121+ } else {
3122+ assert(MI.getMF()->getRegInfo().use_empty(ResultReg) &&
3123+ "Expected no users of the entry pstate.sm!");
3124+ }
3125+ MI.eraseFromParent();
3126+ return BB;
3127+ }
3128+
31043129// Helper function to find the instruction that defined a virtual register.
31053130// If unable to find such instruction, returns nullptr.
31063131static const MachineInstr *stripVRegCopies(const MachineRegisterInfo &MRI,
@@ -3216,6 +3241,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
32163241 return EmitAllocateSMESaveBuffer(MI, BB);
32173242 case AArch64::GetSMESaveSize:
32183243 return EmitGetSMESaveSize(MI, BB);
3244+ case AArch64::EntryPStateSM:
3245+ return EmitEntryPStateSM(MI, BB);
32193246 case AArch64::F128CSEL:
32203247 return EmitF128CSEL(MI, BB);
32213248 case TargetOpcode::STATEPOINT:
@@ -8133,19 +8160,26 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
81338160 }
81348161 assert((ArgLocs.size() + ExtraArgLocs) == Ins.size());
81358162
8163+ if (Attrs.hasStreamingCompatibleInterface()) {
8164+ SDValue EntryPStateSM =
8165+ DAG.getNode(AArch64ISD::ENTRY_PSTATE_SM, DL,
8166+ DAG.getVTList(MVT::i64, MVT::Other), {Chain});
8167+
8168+ // Copy the value to a virtual register, and save that in FuncInfo.
8169+ Register EntryPStateSMReg =
8170+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8171+ Chain = DAG.getCopyToReg(EntryPStateSM.getValue(1), DL, EntryPStateSMReg,
8172+ EntryPStateSM);
8173+ FuncInfo->setPStateSMReg(EntryPStateSMReg);
8174+ }
8175+
81368176 // Insert the SMSTART if this is a locally streaming function and
81378177 // make sure it is Glued to the last CopyFromReg value.
81388178 if (IsLocallyStreaming) {
8139- SDValue PStateSM;
8140- if (Attrs.hasStreamingCompatibleInterface()) {
8141- PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
8142- Register Reg = MF.getRegInfo().createVirtualRegister(
8143- getRegClassFor(PStateSM.getValueType().getSimpleVT()));
8144- FuncInfo->setPStateSMReg(Reg);
8145- Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
8179+ if (Attrs.hasStreamingCompatibleInterface())
81468180 Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
8147- AArch64SME::IfCallerIsNonStreaming, PStateSM );
8148- } else
8181+ AArch64SME::IfCallerIsNonStreaming);
8182+ else
81498183 Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
81508184 AArch64SME::Always);
81518185
@@ -8836,8 +8870,7 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
88368870SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
88378871 bool Enable, SDValue Chain,
88388872 SDValue InGlue,
8839- unsigned Condition,
8840- SDValue PStateSM) const {
8873+ unsigned Condition) const {
88418874 MachineFunction &MF = DAG.getMachineFunction();
88428875 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
88438876 FuncInfo->setHasStreamingModeChanges(true);
@@ -8849,9 +8882,16 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
88498882 SmallVector<SDValue> Ops = {Chain, MSROp};
88508883 unsigned Opcode;
88518884 if (Condition != AArch64SME::Always) {
8885+ FuncInfo->setPStateSMRegUsed(true);
8886+ Register PStateReg = FuncInfo->getPStateSMReg();
8887+ assert(PStateReg.isValid() && "PStateSM Register is invalid");
8888+ SDValue PStateSM =
8889+ DAG.getCopyFromReg(Chain, DL, PStateReg, MVT::i64, InGlue);
8890+ // Use chain and glue from the CopyFromReg.
8891+ Ops[0] = PStateSM.getValue(1);
8892+ InGlue = PStateSM.getValue(2);
88528893 SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
88538894 Opcode = Enable ? AArch64ISD::COND_SMSTART : AArch64ISD::COND_SMSTOP;
8854- assert(PStateSM && "PStateSM should be defined");
88558895 Ops.push_back(ConditionOp);
88568896 Ops.push_back(PStateSM);
88578897 } else {
@@ -9126,15 +9166,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91269166 /*IsSave=*/true);
91279167 }
91289168
9129- SDValue PStateSM;
91309169 bool RequiresSMChange = CallAttrs.requiresSMChange();
91319170 if (RequiresSMChange) {
9132- if (CallAttrs.caller().hasStreamingInterfaceOrBody())
9133- PStateSM = DAG.getConstant(1, DL, MVT::i64);
9134- else if (CallAttrs.caller().hasNonStreamingInterface())
9135- PStateSM = DAG.getConstant(0, DL, MVT::i64);
9136- else
9137- PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
91389171 OptimizationRemarkEmitter ORE(&MF.getFunction());
91399172 ORE.emit([&]() {
91409173 auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
@@ -9449,9 +9482,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94499482 InGlue = Chain.getValue(1);
94509483 }
94519484
9452- SDValue NewChain = changeStreamingMode(
9453- DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue ,
9454- getSMToggleCondition(CallAttrs), PStateSM );
9485+ SDValue NewChain =
9486+ changeStreamingMode( DAG, DL, CallAttrs.callee().hasStreamingInterface(),
9487+ Chain, InGlue, getSMToggleCondition(CallAttrs));
94559488 Chain = NewChain.getValue(0);
94569489 InGlue = NewChain.getValue(1);
94579490 }
@@ -9635,10 +9668,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96359668 InGlue = Result.getValue(Result->getNumValues() - 1);
96369669
96379670 if (RequiresSMChange) {
9638- assert(PStateSM && "Expected a PStateSM to be set");
96399671 Result = changeStreamingMode(
96409672 DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
9641- getSMToggleCondition(CallAttrs), PStateSM );
9673+ getSMToggleCondition(CallAttrs));
96429674
96439675 if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96449676 InGlue = Result.getValue(1);
@@ -9804,14 +9836,11 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
98049836 // Emit SMSTOP before returning from a locally streaming function
98059837 SMEAttrs FuncAttrs = FuncInfo->getSMEFnAttrs();
98069838 if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
9807- if (FuncAttrs.hasStreamingCompatibleInterface()) {
9808- Register Reg = FuncInfo->getPStateSMReg();
9809- assert(Reg.isValid() && "PStateSM Register is invalid");
9810- SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
9839+ if (FuncAttrs.hasStreamingCompatibleInterface())
98119840 Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
98129841 /*Glue*/ SDValue(),
9813- AArch64SME::IfCallerIsNonStreaming, PStateSM );
9814- } else
9842+ AArch64SME::IfCallerIsNonStreaming);
9843+ else
98159844 Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
98169845 /*Glue*/ SDValue(), AArch64SME::Always);
98179846 Glue = Chain.getValue(1);
@@ -28196,6 +28225,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
2819628225 case Intrinsic::aarch64_sme_in_streaming_mode: {
2819728226 SDLoc DL(N);
2819828227 SDValue Chain = DAG.getEntryNode();
28228+
2819928229 SDValue RuntimePStateSM =
2820028230 getRuntimePStateSM(DAG, Chain, DL, N->getValueType(0));
2820128231 Results.push_back(
0 commit comments