@@ -2940,6 +2940,63 @@ AArch64TargetLowering::EmitDynamicProbedAlloc(MachineInstr &MI,
29402940 return NextInst->getParent();
29412941}
29422942
2943+ MachineBasicBlock *
2944+ AArch64TargetLowering::EmitCheckMatchingVL(MachineInstr &MI,
2945+ MachineBasicBlock *MBB) const {
2946+ MachineFunction *MF = MBB->getParent();
2947+ MachineRegisterInfo &MRI = MF->getRegInfo();
2948+
2949+ const TargetRegisterClass *RC_GPR = &AArch64::GPR64RegClass;
2950+ const TargetRegisterClass *RC_GPRsp = &AArch64::GPR64spRegClass;
2951+
2952+ Register RegVL_GPR = MRI.createVirtualRegister(RC_GPR);
2953+ Register RegVL_GPRsp = MRI.createVirtualRegister(RC_GPRsp); // for ADDSVL src
2954+ Register RegSVL_GPR = MRI.createVirtualRegister(RC_GPR);
2955+ Register RegSVL_GPRsp = MRI.createVirtualRegister(RC_GPRsp); // for ADDSVL dst
2956+
2957+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2958+ DebugLoc DL = MI.getDebugLoc();
2959+
2960+ // RDVL requires GPR64, ADDSVL requires GPR64sp
2961+ // We need to insert COPY instructions, these will later be removed by the
2962+ // RegisterCoalescer
2963+ BuildMI(*MBB, MI, DL, TII->get(AArch64::RDVLI_XI), RegVL_GPR).addImm(1);
2964+ BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegVL_GPRsp)
2965+ .addReg(RegVL_GPR);
2966+
2967+ BuildMI(*MBB, MI, DL, TII->get(AArch64::ADDSVL_XXI), RegSVL_GPRsp)
2968+ .addReg(RegVL_GPRsp)
2969+ .addImm(-1);
2970+ BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegSVL_GPR)
2971+ .addReg(RegSVL_GPRsp);
2972+
2973+ const BasicBlock *LLVM_BB = MBB->getBasicBlock();
2974+ MachineFunction::iterator It = ++MBB->getIterator();
2975+ MachineBasicBlock *TrapBB = MF->CreateMachineBasicBlock(LLVM_BB);
2976+ MachineBasicBlock *PassBB = MF->CreateMachineBasicBlock(LLVM_BB);
2977+ MF->insert(It, TrapBB);
2978+ MF->insert(It, PassBB);
2979+
2980+ // Continue if vector lengths match
2981+ BuildMI(*MBB, MI, DL, TII->get(AArch64::CBZX))
2982+ .addReg(RegSVL_GPR)
2983+ .addMBB(PassBB);
2984+
2985+ // Transfer rest of current BB to PassBB
2986+ PassBB->splice(PassBB->begin(), MBB,
2987+ std::next(MachineBasicBlock::iterator(MI)), MBB->end());
2988+ PassBB->transferSuccessorsAndUpdatePHIs(MBB);
2989+
2990+ // Trap if vector lengths mismatch
2991+ BuildMI(TrapBB, DL, TII->get(AArch64::BRK)).addImm(1);
2992+
2993+ MBB->addSuccessor(TrapBB);
2994+ MBB->addSuccessor(PassBB);
2995+
2996+ MI.eraseFromParent();
2997+ return PassBB;
2998+ }
2999+
29433000MachineBasicBlock *
29443001AArch64TargetLowering::EmitTileLoad(unsigned Opc, unsigned BaseReg,
29453002 MachineInstr &MI,
@@ -3343,6 +3400,9 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
33433400 case AArch64::PROBED_STACKALLOC_DYN:
33443401 return EmitDynamicProbedAlloc(MI, BB);
33453402
3403+ case AArch64::CHECK_MATCHING_VL_PSEUDO:
3404+ return EmitCheckMatchingVL(MI, BB);
3405+
33463406 case AArch64::LD1_MXIPXX_H_PSEUDO_B:
33473407 return EmitTileLoad(AArch64::LD1_MXIPXX_H_B, AArch64::ZAB0, MI, BB);
33483408 case AArch64::LD1_MXIPXX_H_PSEUDO_H:
@@ -9119,14 +9179,29 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
91199179 }
91209180}
91219181
9122- SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
9123- bool Enable, SDValue Chain,
9124- SDValue InGlue,
9125- unsigned Condition) const {
9182+ SDValue AArch64TargetLowering::changeStreamingMode(
9183+ SelectionDAG &DAG, SDLoc DL, bool Enable, SDValue Chain, SDValue InGlue,
9184+ unsigned Condition, bool InsertVectorLengthCheck) const {
91269185 MachineFunction &MF = DAG.getMachineFunction();
91279186 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
91289187 FuncInfo->setHasStreamingModeChanges(true);
91299188
9189+ auto GetCheckVL = [&](SDValue Chain, SDValue InGlue = SDValue()) -> SDValue {
9190+ SmallVector<SDValue, 2> Ops = {Chain};
9191+ if (InGlue)
9192+ Ops.push_back(InGlue);
9193+ return DAG.getNode(AArch64ISD::CHECK_MATCHING_VL, DL,
9194+ DAG.getVTList(MVT::Other, MVT::Glue), Ops);
9195+ };
9196+
9197+ if (InsertVectorLengthCheck && Enable) {
9198+ // Non-streaming -> Streaming
9199+ // Insert vector length check before smstart
9200+ SDValue CheckVL = GetCheckVL(Chain, InGlue);
9201+ Chain = CheckVL.getValue(0);
9202+ InGlue = CheckVL.getValue(1);
9203+ }
9204+
91309205 const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
91319206 SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
91329207 SDValue MSROp =
@@ -9153,7 +9228,16 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
91539228 if (InGlue)
91549229 Ops.push_back(InGlue);
91559230
9156- return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
9231+ SDValue SMChange =
9232+ DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
9233+
9234+ if (!InsertVectorLengthCheck || Enable)
9235+ return SMChange;
9236+
9237+ // Streaming -> Non-streaming
9238+ // Insert vector length check after smstop since we cannot read VL
9239+ // in streaming mode
9240+ return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1));
91579241}
91589242
91599243// Emit a call to __arm_sme_save or __arm_sme_restore.
@@ -9735,9 +9819,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
97359819
97369820 SDValue InGlue;
97379821 if (RequiresSMChange) {
9738- Chain =
9739- changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
9740- Chain, InGlue, getSMToggleCondition(CallAttrs));
9822+ bool InsertVectorLengthCheck =
9823+ (CallConv == CallingConv::AArch64_SVE_VectorCall);
9824+ Chain = changeStreamingMode(
9825+ DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
9826+ getSMToggleCondition(CallAttrs), InsertVectorLengthCheck);
97419827 InGlue = Chain.getValue(1);
97429828 }
97439829
0 commit comments