Skip to content

Commit 559cfef

Browse files
authored
[AArch64][SME] Introduce CHECK_MATCHING_VL pseudo for streaming transitions (#157510)
This patch adds a new codegen-only pseudo-instruction, `CHECK_MATCHING_VL`, used when transitioning between non-streaming / streaming-compatible callers and streaming-enabled callees. The pseudo verifies that the current SVE vector length (VL) matches the streaming vector length (SVL); if they differ, we trap.
1 parent a044d61 commit 559cfef

10 files changed

+861
-24
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 94 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2940,6 +2940,63 @@ AArch64TargetLowering::EmitDynamicProbedAlloc(MachineInstr &MI,
29402940
return NextInst->getParent();
29412941
}
29422942

2943+
MachineBasicBlock *
2944+
AArch64TargetLowering::EmitCheckMatchingVL(MachineInstr &MI,
2945+
MachineBasicBlock *MBB) const {
2946+
MachineFunction *MF = MBB->getParent();
2947+
MachineRegisterInfo &MRI = MF->getRegInfo();
2948+
2949+
const TargetRegisterClass *RC_GPR = &AArch64::GPR64RegClass;
2950+
const TargetRegisterClass *RC_GPRsp = &AArch64::GPR64spRegClass;
2951+
2952+
Register RegVL_GPR = MRI.createVirtualRegister(RC_GPR);
2953+
Register RegVL_GPRsp = MRI.createVirtualRegister(RC_GPRsp); // for ADDSVL src
2954+
Register RegSVL_GPR = MRI.createVirtualRegister(RC_GPR);
2955+
Register RegSVL_GPRsp = MRI.createVirtualRegister(RC_GPRsp); // for ADDSVL dst
2956+
2957+
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2958+
DebugLoc DL = MI.getDebugLoc();
2959+
2960+
// RDVL requires GPR64, ADDSVL requires GPR64sp
2961+
// We need to insert COPY instructions, these will later be removed by the
2962+
// RegisterCoalescer
2963+
BuildMI(*MBB, MI, DL, TII->get(AArch64::RDVLI_XI), RegVL_GPR).addImm(1);
2964+
BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegVL_GPRsp)
2965+
.addReg(RegVL_GPR);
2966+
2967+
BuildMI(*MBB, MI, DL, TII->get(AArch64::ADDSVL_XXI), RegSVL_GPRsp)
2968+
.addReg(RegVL_GPRsp)
2969+
.addImm(-1);
2970+
BuildMI(*MBB, MI, DL, TII->get(TargetOpcode::COPY), RegSVL_GPR)
2971+
.addReg(RegSVL_GPRsp);
2972+
2973+
const BasicBlock *LLVM_BB = MBB->getBasicBlock();
2974+
MachineFunction::iterator It = ++MBB->getIterator();
2975+
MachineBasicBlock *TrapBB = MF->CreateMachineBasicBlock(LLVM_BB);
2976+
MachineBasicBlock *PassBB = MF->CreateMachineBasicBlock(LLVM_BB);
2977+
MF->insert(It, TrapBB);
2978+
MF->insert(It, PassBB);
2979+
2980+
// Continue if vector lengths match
2981+
BuildMI(*MBB, MI, DL, TII->get(AArch64::CBZX))
2982+
.addReg(RegSVL_GPR)
2983+
.addMBB(PassBB);
2984+
2985+
// Transfer rest of current BB to PassBB
2986+
PassBB->splice(PassBB->begin(), MBB,
2987+
std::next(MachineBasicBlock::iterator(MI)), MBB->end());
2988+
PassBB->transferSuccessorsAndUpdatePHIs(MBB);
2989+
2990+
// Trap if vector lengths mismatch
2991+
BuildMI(TrapBB, DL, TII->get(AArch64::BRK)).addImm(1);
2992+
2993+
MBB->addSuccessor(TrapBB);
2994+
MBB->addSuccessor(PassBB);
2995+
2996+
MI.eraseFromParent();
2997+
return PassBB;
2998+
}
2999+
29433000
MachineBasicBlock *
29443001
AArch64TargetLowering::EmitTileLoad(unsigned Opc, unsigned BaseReg,
29453002
MachineInstr &MI,
@@ -3343,6 +3400,9 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
33433400
case AArch64::PROBED_STACKALLOC_DYN:
33443401
return EmitDynamicProbedAlloc(MI, BB);
33453402

3403+
case AArch64::CHECK_MATCHING_VL_PSEUDO:
3404+
return EmitCheckMatchingVL(MI, BB);
3405+
33463406
case AArch64::LD1_MXIPXX_H_PSEUDO_B:
33473407
return EmitTileLoad(AArch64::LD1_MXIPXX_H_B, AArch64::ZAB0, MI, BB);
33483408
case AArch64::LD1_MXIPXX_H_PSEUDO_H:
@@ -9119,14 +9179,29 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
91199179
}
91209180
}
91219181

9122-
SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
9123-
bool Enable, SDValue Chain,
9124-
SDValue InGlue,
9125-
unsigned Condition) const {
9182+
SDValue AArch64TargetLowering::changeStreamingMode(
9183+
SelectionDAG &DAG, SDLoc DL, bool Enable, SDValue Chain, SDValue InGlue,
9184+
unsigned Condition, bool InsertVectorLengthCheck) const {
91269185
MachineFunction &MF = DAG.getMachineFunction();
91279186
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
91289187
FuncInfo->setHasStreamingModeChanges(true);
91299188

9189+
auto GetCheckVL = [&](SDValue Chain, SDValue InGlue = SDValue()) -> SDValue {
9190+
SmallVector<SDValue, 2> Ops = {Chain};
9191+
if (InGlue)
9192+
Ops.push_back(InGlue);
9193+
return DAG.getNode(AArch64ISD::CHECK_MATCHING_VL, DL,
9194+
DAG.getVTList(MVT::Other, MVT::Glue), Ops);
9195+
};
9196+
9197+
if (InsertVectorLengthCheck && Enable) {
9198+
// Non-streaming -> Streaming
9199+
// Insert vector length check before smstart
9200+
SDValue CheckVL = GetCheckVL(Chain, InGlue);
9201+
Chain = CheckVL.getValue(0);
9202+
InGlue = CheckVL.getValue(1);
9203+
}
9204+
91309205
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
91319206
SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
91329207
SDValue MSROp =
@@ -9153,7 +9228,16 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
91539228
if (InGlue)
91549229
Ops.push_back(InGlue);
91559230

9156-
return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
9231+
SDValue SMChange =
9232+
DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
9233+
9234+
if (!InsertVectorLengthCheck || Enable)
9235+
return SMChange;
9236+
9237+
// Streaming -> Non-streaming
9238+
// Insert vector length check after smstop since we cannot read VL
9239+
// in streaming mode
9240+
return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1));
91579241
}
91589242

91599243
// Emit a call to __arm_sme_save or __arm_sme_restore.
@@ -9735,9 +9819,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
97359819

97369820
SDValue InGlue;
97379821
if (RequiresSMChange) {
9738-
Chain =
9739-
changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
9740-
Chain, InGlue, getSMToggleCondition(CallAttrs));
9822+
bool InsertVectorLengthCheck =
9823+
(CallConv == CallingConv::AArch64_SVE_VectorCall);
9824+
Chain = changeStreamingMode(
9825+
DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
9826+
getSMToggleCondition(CallAttrs), InsertVectorLengthCheck);
97419827
InGlue = Chain.getValue(1);
97429828
}
97439829

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ class AArch64TargetLowering : public TargetLowering {
168168
MachineBasicBlock *EmitDynamicProbedAlloc(MachineInstr &MI,
169169
MachineBasicBlock *MBB) const;
170170

171+
MachineBasicBlock *EmitCheckMatchingVL(MachineInstr &MI,
172+
MachineBasicBlock *MBB) const;
173+
171174
MachineBasicBlock *EmitTileLoad(unsigned Opc, unsigned BaseReg,
172175
MachineInstr &MI,
173176
MachineBasicBlock *BB) const;
@@ -532,8 +535,8 @@ class AArch64TargetLowering : public TargetLowering {
532535
/// node. \p Condition should be one of the enum values from
533536
/// AArch64SME::ToggleCondition.
534537
SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
535-
SDValue Chain, SDValue InGlue,
536-
unsigned Condition) const;
538+
SDValue Chain, SDValue InGlue, unsigned Condition,
539+
bool InsertVectorLengthCheck = false) const;
537540

538541
bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }
539542

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@ let usesCustomInserter = 1 in {
4848
}
4949
def : Pat<(i64 (AArch64EntryPStateSM)), (EntryPStateSM)>;
5050

51+
// Pseudo-instruction that compares the current SVE vector length (VL) with the
52+
// streaming vector length (SVL). If the two lengths do not match, the check
53+
// lowers to a `brk`, causing a trap.
54+
let hasSideEffects = 1, isCodeGenOnly = 1, usesCustomInserter = 1 in
55+
def CHECK_MATCHING_VL_PSEUDO : Pseudo<(outs), (ins), []>, Sched<[]>;
56+
57+
def AArch64_check_matching_vl
58+
: SDNode<"AArch64ISD::CHECK_MATCHING_VL", SDTypeProfile<0, 0,[]>,
59+
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue]>;
60+
def : Pat<(AArch64_check_matching_vl), (CHECK_MATCHING_VL_PSEUDO)>;
61+
5162
//===----------------------------------------------------------------------===//
5263
// Old SME ABI lowering ISD nodes/pseudos (deprecated)
5364
//===----------------------------------------------------------------------===//

llvm/test/CodeGen/AArch64/sme-callee-save-restore-pairs.ll

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,18 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
4747
; NOPAIR-NEXT: // %bb.1:
4848
; NOPAIR-NEXT: smstop sm
4949
; NOPAIR-NEXT: .LBB0_2:
50+
; NOPAIR-NEXT: rdvl x8, #1
51+
; NOPAIR-NEXT: addsvl x8, x8, #-1
52+
; NOPAIR-NEXT: cbz x8, .LBB0_4
53+
; NOPAIR-NEXT: // %bb.3:
54+
; NOPAIR-NEXT: brk #0x1
55+
; NOPAIR-NEXT: .LBB0_4:
5056
; NOPAIR-NEXT: ldr z0, [sp] // 16-byte Folded Reload
5157
; NOPAIR-NEXT: bl my_func2
52-
; NOPAIR-NEXT: tbz w19, #0, .LBB0_4
53-
; NOPAIR-NEXT: // %bb.3:
58+
; NOPAIR-NEXT: tbz w19, #0, .LBB0_6
59+
; NOPAIR-NEXT: // %bb.5:
5460
; NOPAIR-NEXT: smstart sm
55-
; NOPAIR-NEXT: .LBB0_4:
61+
; NOPAIR-NEXT: .LBB0_6:
5662
; NOPAIR-NEXT: addvl sp, sp, #1
5763
; NOPAIR-NEXT: ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload
5864
; NOPAIR-NEXT: ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload
@@ -127,12 +133,18 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
127133
; PAIR-NEXT: // %bb.1:
128134
; PAIR-NEXT: smstop sm
129135
; PAIR-NEXT: .LBB0_2:
136+
; PAIR-NEXT: rdvl x8, #1
137+
; PAIR-NEXT: addsvl x8, x8, #-1
138+
; PAIR-NEXT: cbz x8, .LBB0_4
139+
; PAIR-NEXT: // %bb.3:
140+
; PAIR-NEXT: brk #0x1
141+
; PAIR-NEXT: .LBB0_4:
130142
; PAIR-NEXT: ldr z0, [sp] // 16-byte Folded Reload
131143
; PAIR-NEXT: bl my_func2
132-
; PAIR-NEXT: tbz w19, #0, .LBB0_4
133-
; PAIR-NEXT: // %bb.3:
144+
; PAIR-NEXT: tbz w19, #0, .LBB0_6
145+
; PAIR-NEXT: // %bb.5:
134146
; PAIR-NEXT: smstart sm
135-
; PAIR-NEXT: .LBB0_4:
147+
; PAIR-NEXT: .LBB0_6:
136148
; PAIR-NEXT: addvl sp, sp, #1
137149
; PAIR-NEXT: ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload
138150
; PAIR-NEXT: ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload

llvm/test/CodeGen/AArch64/sme-peephole-opts.ll

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,14 +526,24 @@ define void @test13(ptr %ptr) nounwind "aarch64_pstate_sm_enabled" {
526526
; CHECK-NEXT: stp x30, x19, [sp, #80] // 16-byte Folded Spill
527527
; CHECK-NEXT: addvl sp, sp, #-1
528528
; CHECK-NEXT: mov z0.s, #0 // =0x0
529-
; CHECK-NEXT: mov x19, x0
530529
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill
531530
; CHECK-NEXT: smstop sm
531+
; CHECK-NEXT: rdvl x8, #1
532+
; CHECK-NEXT: addsvl x8, x8, #-1
533+
; CHECK-NEXT: cbnz x8, .LBB14_2
534+
; CHECK-NEXT: // %bb.1:
532535
; CHECK-NEXT: ldr z0, [sp] // 16-byte Folded Reload
536+
; CHECK-NEXT: mov x19, x0
533537
; CHECK-NEXT: bl callee_farg_fret
534538
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill
535539
; CHECK-NEXT: smstart sm
536540
; CHECK-NEXT: smstop sm
541+
; CHECK-NEXT: rdvl x8, #1
542+
; CHECK-NEXT: addsvl x8, x8, #-1
543+
; CHECK-NEXT: cbz x8, .LBB14_3
544+
; CHECK-NEXT: .LBB14_2:
545+
; CHECK-NEXT: brk #0x1
546+
; CHECK-NEXT: .LBB14_3:
537547
; CHECK-NEXT: ldr z0, [sp] // 16-byte Folded Reload
538548
; CHECK-NEXT: bl callee_farg_fret
539549
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill

0 commit comments

Comments
 (0)