@@ -3101,6 +3101,31 @@ AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
3101
3101
return BB;
3102
3102
}
3103
3103
3104
+ MachineBasicBlock *
3105
+ AArch64TargetLowering::EmitEntryPStateSM(MachineInstr &MI,
3106
+ MachineBasicBlock *BB) const {
3107
+ MachineFunction *MF = BB->getParent();
3108
+ AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
3109
+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3110
+ Register ResultReg = MI.getOperand(0).getReg();
3111
+ if (FuncInfo->isPStateSMRegUsed()) {
3112
+ const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
3113
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
3114
+ .addExternalSymbol("__arm_sme_state")
3115
+ .addReg(AArch64::X0, RegState::ImplicitDefine)
3116
+ .addRegMask(TRI->getCallPreservedMask(
3117
+ *MF, CallingConv::
3118
+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2));
3119
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), ResultReg)
3120
+ .addReg(AArch64::X0);
3121
+ } else {
3122
+ assert(MI.getMF()->getRegInfo().use_empty(ResultReg) &&
3123
+ "Expected no users of the entry pstate.sm!");
3124
+ }
3125
+ MI.eraseFromParent();
3126
+ return BB;
3127
+ }
3128
+
3104
3129
// Helper function to find the instruction that defined a virtual register.
3105
3130
// If unable to find such instruction, returns nullptr.
3106
3131
static const MachineInstr *stripVRegCopies(const MachineRegisterInfo &MRI,
@@ -3216,6 +3241,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
3216
3241
return EmitAllocateSMESaveBuffer(MI, BB);
3217
3242
case AArch64::GetSMESaveSize:
3218
3243
return EmitGetSMESaveSize(MI, BB);
3244
+ case AArch64::EntryPStateSM:
3245
+ return EmitEntryPStateSM(MI, BB);
3219
3246
case AArch64::F128CSEL:
3220
3247
return EmitF128CSEL(MI, BB);
3221
3248
case TargetOpcode::STATEPOINT:
@@ -8133,19 +8160,26 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
8133
8160
}
8134
8161
assert((ArgLocs.size() + ExtraArgLocs) == Ins.size());
8135
8162
8163
+ if (Attrs.hasStreamingCompatibleInterface()) {
8164
+ SDValue EntryPStateSM =
8165
+ DAG.getNode(AArch64ISD::ENTRY_PSTATE_SM, DL,
8166
+ DAG.getVTList(MVT::i64, MVT::Other), {Chain});
8167
+
8168
+ // Copy the value to a virtual register, and save that in FuncInfo.
8169
+ Register EntryPStateSMReg =
8170
+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
8171
+ Chain = DAG.getCopyToReg(EntryPStateSM.getValue(1), DL, EntryPStateSMReg,
8172
+ EntryPStateSM);
8173
+ FuncInfo->setPStateSMReg(EntryPStateSMReg);
8174
+ }
8175
+
8136
8176
// Insert the SMSTART if this is a locally streaming function and
8137
8177
// make sure it is Glued to the last CopyFromReg value.
8138
8178
if (IsLocallyStreaming) {
8139
- SDValue PStateSM;
8140
- if (Attrs.hasStreamingCompatibleInterface()) {
8141
- PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
8142
- Register Reg = MF.getRegInfo().createVirtualRegister(
8143
- getRegClassFor(PStateSM.getValueType().getSimpleVT()));
8144
- FuncInfo->setPStateSMReg(Reg);
8145
- Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
8179
+ if (Attrs.hasStreamingCompatibleInterface())
8146
8180
Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
8147
- AArch64SME::IfCallerIsNonStreaming, PStateSM );
8148
- } else
8181
+ AArch64SME::IfCallerIsNonStreaming);
8182
+ else
8149
8183
Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
8150
8184
AArch64SME::Always);
8151
8185
@@ -8836,8 +8870,7 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
8836
8870
SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
8837
8871
bool Enable, SDValue Chain,
8838
8872
SDValue InGlue,
8839
- unsigned Condition,
8840
- SDValue PStateSM) const {
8873
+ unsigned Condition) const {
8841
8874
MachineFunction &MF = DAG.getMachineFunction();
8842
8875
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8843
8876
FuncInfo->setHasStreamingModeChanges(true);
@@ -8849,9 +8882,16 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
8849
8882
SmallVector<SDValue> Ops = {Chain, MSROp};
8850
8883
unsigned Opcode;
8851
8884
if (Condition != AArch64SME::Always) {
8885
+ FuncInfo->setPStateSMRegUsed(true);
8886
+ Register PStateReg = FuncInfo->getPStateSMReg();
8887
+ assert(PStateReg.isValid() && "PStateSM Register is invalid");
8888
+ SDValue PStateSM =
8889
+ DAG.getCopyFromReg(Chain, DL, PStateReg, MVT::i64, InGlue);
8890
+ // Use chain and glue from the CopyFromReg.
8891
+ Ops[0] = PStateSM.getValue(1);
8892
+ InGlue = PStateSM.getValue(2);
8852
8893
SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
8853
8894
Opcode = Enable ? AArch64ISD::COND_SMSTART : AArch64ISD::COND_SMSTOP;
8854
- assert(PStateSM && "PStateSM should be defined");
8855
8895
Ops.push_back(ConditionOp);
8856
8896
Ops.push_back(PStateSM);
8857
8897
} else {
@@ -9126,15 +9166,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9126
9166
/*IsSave=*/true);
9127
9167
}
9128
9168
9129
- SDValue PStateSM;
9130
9169
bool RequiresSMChange = CallAttrs.requiresSMChange();
9131
9170
if (RequiresSMChange) {
9132
- if (CallAttrs.caller().hasStreamingInterfaceOrBody())
9133
- PStateSM = DAG.getConstant(1, DL, MVT::i64);
9134
- else if (CallAttrs.caller().hasNonStreamingInterface())
9135
- PStateSM = DAG.getConstant(0, DL, MVT::i64);
9136
- else
9137
- PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
9138
9171
OptimizationRemarkEmitter ORE(&MF.getFunction());
9139
9172
ORE.emit([&]() {
9140
9173
auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
@@ -9449,9 +9482,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9449
9482
InGlue = Chain.getValue(1);
9450
9483
}
9451
9484
9452
- SDValue NewChain = changeStreamingMode(
9453
- DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue ,
9454
- getSMToggleCondition(CallAttrs), PStateSM );
9485
+ SDValue NewChain =
9486
+ changeStreamingMode( DAG, DL, CallAttrs.callee().hasStreamingInterface(),
9487
+ Chain, InGlue, getSMToggleCondition(CallAttrs));
9455
9488
Chain = NewChain.getValue(0);
9456
9489
InGlue = NewChain.getValue(1);
9457
9490
}
@@ -9635,10 +9668,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9635
9668
InGlue = Result.getValue(Result->getNumValues() - 1);
9636
9669
9637
9670
if (RequiresSMChange) {
9638
- assert(PStateSM && "Expected a PStateSM to be set");
9639
9671
Result = changeStreamingMode(
9640
9672
DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
9641
- getSMToggleCondition(CallAttrs), PStateSM );
9673
+ getSMToggleCondition(CallAttrs));
9642
9674
9643
9675
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
9644
9676
InGlue = Result.getValue(1);
@@ -9804,14 +9836,11 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
9804
9836
// Emit SMSTOP before returning from a locally streaming function
9805
9837
SMEAttrs FuncAttrs = FuncInfo->getSMEFnAttrs();
9806
9838
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
9807
- if (FuncAttrs.hasStreamingCompatibleInterface()) {
9808
- Register Reg = FuncInfo->getPStateSMReg();
9809
- assert(Reg.isValid() && "PStateSM Register is invalid");
9810
- SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
9839
+ if (FuncAttrs.hasStreamingCompatibleInterface())
9811
9840
Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
9812
9841
/*Glue*/ SDValue(),
9813
- AArch64SME::IfCallerIsNonStreaming, PStateSM );
9814
- } else
9842
+ AArch64SME::IfCallerIsNonStreaming);
9843
+ else
9815
9844
Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
9816
9845
/*Glue*/ SDValue(), AArch64SME::Always);
9817
9846
Glue = Chain.getValue(1);
@@ -28196,6 +28225,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
28196
28225
case Intrinsic::aarch64_sme_in_streaming_mode: {
28197
28226
SDLoc DL(N);
28198
28227
SDValue Chain = DAG.getEntryNode();
28228
+
28199
28229
SDValue RuntimePStateSM =
28200
28230
getRuntimePStateSM(DAG, Chain, DL, N->getValueType(0));
28201
28231
Results.push_back(
0 commit comments