Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 93 additions & 8 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2940,6 +2940,56 @@ AArch64TargetLowering::EmitDynamicProbedAlloc(MachineInstr &MI,
return NextInst->getParent();
}

MachineBasicBlock *
AArch64TargetLowering::EmitCheckMatchingVL(MachineInstr &MI,
MachineBasicBlock *MBB) const {
MachineFunction *MF = MBB->getParent();
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
const BasicBlock *LLVM_BB = MBB->getBasicBlock();
DebugLoc DL = MI.getDebugLoc();
MachineFunction::iterator It = ++MBB->getIterator();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super-nit: Rename to something explicit like MBBInsertPoint

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super-nit: Also please move this closer to its use rather than defining it at the top of the function.


const TargetRegisterClass *RC = &AArch64::GPR64RegClass;
MachineRegisterInfo &MRI = MF->getRegInfo();

Register RegVL = MRI.createVirtualRegister(RC);
Register RegSVL = MRI.createVirtualRegister(RC);
Register RegCheck = MRI.createVirtualRegister(RC);

// Read VL and Streaming VL
BuildMI(*MBB, MI, DL, TII->get(AArch64::RDVLI_XI), RegVL).addImm(1);
BuildMI(*MBB, MI, DL, TII->get(AArch64::RDSVLI_XI), RegSVL).addImm(1);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this comment (and some similar comments below) is redundant because it's clear from the statement below that is what it's doing. It might be more helpful to explain (once) why the COPY's are needed , e.g. something along the lines of "because ADDSVL requires GPR64sp and RDVL requires GPR64, we need to insert some COPYs that will be removed by the RegisterCoalescer".

// Compare vector lengths
BuildMI(*MBB, MI, DL, TII->get(AArch64::SUBXrr), RegCheck)
.addReg(RegVL)
.addReg(RegSVL);

MachineBasicBlock *TrapBB = MF->CreateMachineBasicBlock(LLVM_BB);
MachineBasicBlock *PassBB = MF->CreateMachineBasicBlock(LLVM_BB);
MF->insert(It, TrapBB);
MF->insert(It, PassBB);

// Continue if vector lengths match
BuildMI(*MBB, MI, DL, TII->get(AArch64::CBZX))
.addReg(RegCheck)
.addMBB(PassBB);

// Transfer rest of current BB to PassBB
PassBB->splice(PassBB->begin(), MBB,
std::next(MachineBasicBlock::iterator(MI)), MBB->end());
PassBB->transferSuccessorsAndUpdatePHIs(MBB);

// Trap if vector lengths mismatch
BuildMI(TrapBB, DL, TII->get(AArch64::BRK)).addImm(1);

MBB->addSuccessor(TrapBB);
MBB->addSuccessor(PassBB);

MI.eraseFromParent();
return PassBB;
}

MachineBasicBlock *
AArch64TargetLowering::EmitTileLoad(unsigned Opc, unsigned BaseReg,
MachineInstr &MI,
Expand Down Expand Up @@ -3343,6 +3393,9 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
case AArch64::PROBED_STACKALLOC_DYN:
return EmitDynamicProbedAlloc(MI, BB);

case AArch64::CHECK_MATCHING_VL_PSEUDO:
return EmitCheckMatchingVL(MI, BB);

case AArch64::LD1_MXIPXX_H_PSEUDO_B:
return EmitTileLoad(AArch64::LD1_MXIPXX_H_B, AArch64::ZAB0, MI, BB);
case AArch64::LD1_MXIPXX_H_PSEUDO_H:
Expand Down Expand Up @@ -9113,10 +9166,9 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
}
}

SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
bool Enable, SDValue Chain,
SDValue InGlue,
unsigned Condition) const {
SDValue AArch64TargetLowering::changeStreamingMode(
SelectionDAG &DAG, SDLoc DL, bool Enable, SDValue Chain, SDValue InGlue,
unsigned Condition, bool InsertVectorLengthCheck) const {
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
FuncInfo->setHasStreamingModeChanges(true);
Expand Down Expand Up @@ -9147,7 +9199,38 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
if (InGlue)
Ops.push_back(InGlue);

return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
if (!InsertVectorLengthCheck)
return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Supert-nit: I would "inilne" that lambda, I feel something like below is easier to read, and does not duplicate too much code.

SDValue CheckVL = SDValue(DAG.getMachineNode(AArch64::CHECK_MATCHING_VL, DL,
                                      DAG.getVTList(MVT::Other, MVT::Glue),
                                      {Chaine, InGlue}), 0);

But really, feel free to keep as is if you prefer.


auto GetCheckVL = [&](SDValue Chain, SDValue InGlue = SDValue()) -> SDValue {
SmallVector<SDValue, 2> Ops = {Chain};
if (InGlue)
Ops.push_back(InGlue);
return DAG.getNode(AArch64ISD::CHECK_MATCHING_VL, DL,
DAG.getVTList(MVT::Other, MVT::Glue), Ops);
};

// Non-streaming -> Streaming
if (Enable) {
SDValue CheckVL = GetCheckVL(Chain, InGlue);

// Replace chain
Ops[0] = CheckVL.getValue(0);

// Replace/append glue
if (InGlue)
Ops.back() = CheckVL.getValue(1);
else
Ops.push_back(CheckVL.getValue(1));

return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
}

// Streaming -> Non-streaming
SDValue StreamingModeInstr =
DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
return GetCheckVL(StreamingModeInstr.getValue(0),
StreamingModeInstr.getValue(1));
}

// Emit a call to __arm_sme_save or __arm_sme_restore.
Expand Down Expand Up @@ -9730,9 +9813,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

SDValue InGlue;
if (RequiresSMChange) {
Chain =
changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
Chain, InGlue, getSMToggleCondition(CallAttrs));
bool InsertVectorLengthCheck =
(CallConv == CallingConv::AArch64_SVE_VectorCall);
Chain = changeStreamingMode(
DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
getSMToggleCondition(CallAttrs), InsertVectorLengthCheck);
InGlue = Chain.getValue(1);
}

Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ class AArch64TargetLowering : public TargetLowering {
MachineBasicBlock *EmitDynamicProbedAlloc(MachineInstr &MI,
MachineBasicBlock *MBB) const;

MachineBasicBlock *EmitCheckMatchingVL(MachineInstr &MI,
MachineBasicBlock *MBB) const;

MachineBasicBlock *EmitTileLoad(unsigned Opc, unsigned BaseReg,
MachineInstr &MI,
MachineBasicBlock *BB) const;
Expand Down Expand Up @@ -532,8 +535,8 @@ class AArch64TargetLowering : public TargetLowering {
/// node. \p Condition should be one of the enum values from
/// AArch64SME::ToggleCondition.
SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
SDValue Chain, SDValue InGlue,
unsigned Condition) const;
SDValue Chain, SDValue InGlue, unsigned Condition,
bool InsertVectorLengthCheck = false) const;

bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }

Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@ let usesCustomInserter = 1 in {
}
def : Pat<(i64 (AArch64EntryPStateSM)), (EntryPStateSM)>;

// Pseudo-instruction that compares the current SVE vector length (VL) with the
// streaming vector length (SVL). If the two lengths do not match, the check
// lowers to a `brk`, causing a trap.
let hasSideEffects = 1, isCodeGenOnly = 1, usesCustomInserter = 1 in
def CHECK_MATCHING_VL_PSEUDO : Pseudo<(outs), (ins), []>, Sched<[]>;

def AArch64_check_matching_vl
: SDNode<"AArch64ISD::CHECK_MATCHING_VL", SDTypeProfile<0, 0,[]>,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue]>;
def : Pat<(AArch64_check_matching_vl), (CHECK_MATCHING_VL_PSEUDO)>;

//===----------------------------------------------------------------------===//
// Old SME ABI lowering ISD nodes/pseudos (deprecated)
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 20 additions & 6 deletions llvm/test/CodeGen/AArch64/sme-callee-save-restore-pairs.ll
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,19 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
; NOPAIR-NEXT: // %bb.1:
; NOPAIR-NEXT: smstop sm
; NOPAIR-NEXT: .LBB0_2:
; NOPAIR-NEXT: rdvl x8, #1
; NOPAIR-NEXT: rdsvl x9, #1
; NOPAIR-NEXT: cmp x8, x9
; NOPAIR-NEXT: b.eq .LBB0_4
; NOPAIR-NEXT: // %bb.3:
; NOPAIR-NEXT: brk #0x1
; NOPAIR-NEXT: .LBB0_4:
; NOPAIR-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; NOPAIR-NEXT: bl my_func2
; NOPAIR-NEXT: tbz w19, #0, .LBB0_4
; NOPAIR-NEXT: // %bb.3:
; NOPAIR-NEXT: tbz w19, #0, .LBB0_6
; NOPAIR-NEXT: // %bb.5:
; NOPAIR-NEXT: smstart sm
; NOPAIR-NEXT: .LBB0_4:
; NOPAIR-NEXT: .LBB0_6:
; NOPAIR-NEXT: addvl sp, sp, #1
; NOPAIR-NEXT: ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload
; NOPAIR-NEXT: ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload
Expand Down Expand Up @@ -127,12 +134,19 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
; PAIR-NEXT: // %bb.1:
; PAIR-NEXT: smstop sm
; PAIR-NEXT: .LBB0_2:
; PAIR-NEXT: rdvl x8, #1
; PAIR-NEXT: rdsvl x9, #1
; PAIR-NEXT: cmp x8, x9
; PAIR-NEXT: b.eq .LBB0_4
; PAIR-NEXT: // %bb.3:
; PAIR-NEXT: brk #0x1
; PAIR-NEXT: .LBB0_4:
; PAIR-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; PAIR-NEXT: bl my_func2
; PAIR-NEXT: tbz w19, #0, .LBB0_4
; PAIR-NEXT: // %bb.3:
; PAIR-NEXT: tbz w19, #0, .LBB0_6
; PAIR-NEXT: // %bb.5:
; PAIR-NEXT: smstart sm
; PAIR-NEXT: .LBB0_4:
; PAIR-NEXT: .LBB0_6:
; PAIR-NEXT: addvl sp, sp, #1
; PAIR-NEXT: ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload
; PAIR-NEXT: ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload
Expand Down
14 changes: 13 additions & 1 deletion llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,26 @@ define void @test13(ptr %ptr) nounwind "aarch64_pstate_sm_enabled" {
; CHECK-NEXT: stp x30, x19, [sp, #80] // 16-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: mov z0.s, #0 // =0x0
; CHECK-NEXT: mov x19, x0
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill
; CHECK-NEXT: smstop sm
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: rdsvl x9, #1
; CHECK-NEXT: cmp x8, x9
; CHECK-NEXT: b.ne .LBB14_2
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; CHECK-NEXT: mov x19, x0
; CHECK-NEXT: bl callee_farg_fret
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill
; CHECK-NEXT: smstart sm
; CHECK-NEXT: smstop sm
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: rdsvl x9, #1
; CHECK-NEXT: cmp x8, x9
; CHECK-NEXT: b.eq .LBB14_3
; CHECK-NEXT: .LBB14_2:
; CHECK-NEXT: brk #0x1
; CHECK-NEXT: .LBB14_3:
; CHECK-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; CHECK-NEXT: bl callee_farg_fret
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill
Expand Down
Loading
Loading