@@ -2940,6 +2940,63 @@ AArch64TargetLowering::EmitDynamicProbedAlloc(MachineInstr &MI,
2940
2940
return NextInst->getParent();
2941
2941
}
2942
2942
2943
+ MachineBasicBlock *
2944
+ AArch64TargetLowering::EmitCheckMatchingVL(MachineInstr &MI,
2945
+ MachineBasicBlock *MBB) const {
2946
+ MachineFunction *MF = MBB->getParent();
2947
+ MachineRegisterInfo &MRI = MF->getRegInfo();
2948
+
2949
+ const TargetRegisterClass *RC_GPR = &AArch64::GPR64RegClass;
2950
+ const TargetRegisterClass *RC_GPRsp = &AArch64::GPR64spRegClass;
2951
+
2952
+ Register RegVL_GPR = MRI.createVirtualRegister(RC_GPR);
2953
+ Register RegVL_GPRsp = MRI.createVirtualRegister(RC_GPRsp); // for ADDSVL src
2954
+ Register RegSVL_GPR = MRI.createVirtualRegister(RC_GPR);
2955
+ Register RegSVL_GPRsp = MRI.createVirtualRegister(RC_GPRsp); // for ADDSVL dst
2956
+
2957
+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2958
+ DebugLoc DL = MI.getDebugLoc();
2959
+
2960
+ // RDVL requires GPR64, ADDSVL requires GPR64sp
2961
+ // We need to insert COPY instructions, these will later be removed by the
2962
+ // RegisterCoalescer
2963
+ BuildMI(*MBB, MI, DL, TII->get(AArch64::RDVLI_XI), RegVL_GPR).addImm(1);
2964
+ BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegVL_GPRsp)
2965
+ .addReg(RegVL_GPR);
2966
+
2967
+ BuildMI(*MBB, MI, DL, TII->get(AArch64::ADDSVL_XXI), RegSVL_GPRsp)
2968
+ .addReg(RegVL_GPRsp)
2969
+ .addImm(-1);
2970
+ BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegSVL_GPR)
2971
+ .addReg(RegSVL_GPRsp);
2972
+
2973
+ const BasicBlock *LLVM_BB = MBB->getBasicBlock();
2974
+ MachineFunction::iterator It = ++MBB->getIterator();
2975
+ MachineBasicBlock *TrapBB = MF->CreateMachineBasicBlock(LLVM_BB);
2976
+ MachineBasicBlock *PassBB = MF->CreateMachineBasicBlock(LLVM_BB);
2977
+ MF->insert(It, TrapBB);
2978
+ MF->insert(It, PassBB);
2979
+
2980
+ // Continue if vector lengths match
2981
+ BuildMI(*MBB, MI, DL, TII->get(AArch64::CBZX))
2982
+ .addReg(RegSVL_GPR)
2983
+ .addMBB(PassBB);
2984
+
2985
+ // Transfer rest of current BB to PassBB
2986
+ PassBB->splice(PassBB->begin(), MBB,
2987
+ std::next(MachineBasicBlock::iterator(MI)), MBB->end());
2988
+ PassBB->transferSuccessorsAndUpdatePHIs(MBB);
2989
+
2990
+ // Trap if vector lengths mismatch
2991
+ BuildMI(TrapBB, DL, TII->get(AArch64::BRK)).addImm(1);
2992
+
2993
+ MBB->addSuccessor(TrapBB);
2994
+ MBB->addSuccessor(PassBB);
2995
+
2996
+ MI.eraseFromParent();
2997
+ return PassBB;
2998
+ }
2999
+
2943
3000
MachineBasicBlock *
2944
3001
AArch64TargetLowering::EmitTileLoad(unsigned Opc, unsigned BaseReg,
2945
3002
MachineInstr &MI,
@@ -3343,6 +3400,9 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
3343
3400
case AArch64::PROBED_STACKALLOC_DYN:
3344
3401
return EmitDynamicProbedAlloc(MI, BB);
3345
3402
3403
+ case AArch64::CHECK_MATCHING_VL_PSEUDO:
3404
+ return EmitCheckMatchingVL(MI, BB);
3405
+
3346
3406
case AArch64::LD1_MXIPXX_H_PSEUDO_B:
3347
3407
return EmitTileLoad(AArch64::LD1_MXIPXX_H_B, AArch64::ZAB0, MI, BB);
3348
3408
case AArch64::LD1_MXIPXX_H_PSEUDO_H:
@@ -9119,14 +9179,29 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
9119
9179
}
9120
9180
}
9121
9181
9122
- SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
9123
- bool Enable, SDValue Chain,
9124
- SDValue InGlue,
9125
- unsigned Condition) const {
9182
+ SDValue AArch64TargetLowering::changeStreamingMode(
9183
+ SelectionDAG &DAG, SDLoc DL, bool Enable, SDValue Chain, SDValue InGlue,
9184
+ unsigned Condition, bool InsertVectorLengthCheck) const {
9126
9185
MachineFunction &MF = DAG.getMachineFunction();
9127
9186
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
9128
9187
FuncInfo->setHasStreamingModeChanges(true);
9129
9188
9189
+ auto GetCheckVL = [&](SDValue Chain, SDValue InGlue = SDValue()) -> SDValue {
9190
+ SmallVector<SDValue, 2> Ops = {Chain};
9191
+ if (InGlue)
9192
+ Ops.push_back(InGlue);
9193
+ return DAG.getNode(AArch64ISD::CHECK_MATCHING_VL, DL,
9194
+ DAG.getVTList(MVT::Other, MVT::Glue), Ops);
9195
+ };
9196
+
9197
+ if (InsertVectorLengthCheck && Enable) {
9198
+ // Non-streaming -> Streaming
9199
+ // Insert vector length check before smstart
9200
+ SDValue CheckVL = GetCheckVL(Chain, InGlue);
9201
+ Chain = CheckVL.getValue(0);
9202
+ InGlue = CheckVL.getValue(1);
9203
+ }
9204
+
9130
9205
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
9131
9206
SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
9132
9207
SDValue MSROp =
@@ -9153,7 +9228,16 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
9153
9228
if (InGlue)
9154
9229
Ops.push_back(InGlue);
9155
9230
9156
- return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
9231
+ SDValue SMChange =
9232
+ DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
9233
+
9234
+ if (!InsertVectorLengthCheck || Enable)
9235
+ return SMChange;
9236
+
9237
+ // Streaming -> Non-streaming
9238
+ // Insert vector length check after smstop since we cannot read VL
9239
+ // in streaming mode
9240
+ return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1));
9157
9241
}
9158
9242
9159
9243
// Emit a call to __arm_sme_save or __arm_sme_restore.
@@ -9735,9 +9819,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9735
9819
9736
9820
SDValue InGlue;
9737
9821
if (RequiresSMChange) {
9738
- Chain =
9739
- changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
9740
- Chain, InGlue, getSMToggleCondition(CallAttrs));
9822
+ bool InsertVectorLengthCheck =
9823
+ (CallConv == CallingConv::AArch64_SVE_VectorCall);
9824
+ Chain = changeStreamingMode(
9825
+ DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
9826
+ getSMToggleCondition(CallAttrs), InsertVectorLengthCheck);
9741
9827
InGlue = Chain.getValue(1);
9742
9828
}
9743
9829
0 commit comments