Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 87 additions & 3 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2940,6 +2940,52 @@ AArch64TargetLowering::EmitDynamicProbedAlloc(MachineInstr &MI,
return NextInst->getParent();
}

MachineBasicBlock *
AArch64TargetLowering::EmitCheckVL(MachineInstr &MI,
MachineBasicBlock *MBB) const {
MachineFunction *MF = MBB->getParent();
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
const BasicBlock *LLVM_BB = MBB->getBasicBlock();
DebugLoc DL = MI.getDebugLoc();
MachineFunction::iterator It = ++MBB->getIterator();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super-nit: Rename to something explicit like MBBInsertPoint

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super-nit: Also please move this closer to its use rather than defining it at the top of the function.


const TargetRegisterClass *RC = &AArch64::GPR64RegClass;
MachineRegisterInfo &MRI = MF->getRegInfo();

Register RegVL = MRI.createVirtualRegister(RC);
Register RegSVL = MRI.createVirtualRegister(RC);
Register RegCheck = MRI.createVirtualRegister(RC);

BuildMI(*MBB, MI, DL, TII->get(AArch64::RDVLI_XI), RegVL).addImm(1);
BuildMI(*MBB, MI, DL, TII->get(AArch64::RDSVLI_XI), RegSVL).addImm(1);

BuildMI(*MBB, MI, DL, TII->get(AArch64::SUBXrr), RegCheck)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this comment (and some similar comments below) is redundant because it's clear from the statement below that is what it's doing. It might be more helpful to explain (once) why the COPY's are needed , e.g. something along the lines of "because ADDSVL requires GPR64sp and RDVL requires GPR64, we need to insert some COPYs that will be removed by the RegisterCoalescer".

.addReg(RegVL)
.addReg(RegSVL);

MachineBasicBlock *TrapBB = MF->CreateMachineBasicBlock(LLVM_BB);
MachineBasicBlock *PassBB = MF->CreateMachineBasicBlock(LLVM_BB);
MF->insert(It, TrapBB);
MF->insert(It, PassBB);

BuildMI(*MBB, MI, DL, TII->get(AArch64::CBZX))
.addReg(RegCheck)
.addMBB(PassBB);

// Transfer rest of current BB to PassBB
PassBB->splice(PassBB->begin(), MBB,
std::next(MachineBasicBlock::iterator(MI)), MBB->end());
PassBB->transferSuccessorsAndUpdatePHIs(MBB);

BuildMI(TrapBB, DL, TII->get(AArch64::BRK)).addImm(1);

MBB->addSuccessor(TrapBB);
MBB->addSuccessor(PassBB);

MI.eraseFromParent();
return PassBB;
}

MachineBasicBlock *
AArch64TargetLowering::EmitTileLoad(unsigned Opc, unsigned BaseReg,
MachineInstr &MI,
Expand Down Expand Up @@ -3343,6 +3389,9 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
case AArch64::PROBED_STACKALLOC_DYN:
return EmitDynamicProbedAlloc(MI, BB);

case AArch64::CHECK_MATCHING_VL:
return EmitCheckVL(MI, BB);

case AArch64::LD1_MXIPXX_H_PSEUDO_B:
return EmitTileLoad(AArch64::LD1_MXIPXX_H_B, AArch64::ZAB0, MI, BB);
case AArch64::LD1_MXIPXX_H_PSEUDO_H:
Expand Down Expand Up @@ -9116,7 +9165,8 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
bool Enable, SDValue Chain,
SDValue InGlue,
unsigned Condition) const {
unsigned Condition,
bool HasSVECC) const {
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
FuncInfo->setHasStreamingModeChanges(true);
Expand Down Expand Up @@ -9147,7 +9197,40 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
if (InGlue)
Ops.push_back(InGlue);

return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
if (!HasSVECC)
return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);

auto GetCheckVL = [&](SDValue Chain, SDValue InGlue = SDValue()) -> SDValue {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Supert-nit: I would "inilne" that lambda, I feel something like below is easier to read, and does not duplicate too much code.

SDValue CheckVL = SDValue(DAG.getMachineNode(AArch64::CHECK_MATCHING_VL, DL,
                                      DAG.getVTList(MVT::Other, MVT::Glue),
                                      {Chaine, InGlue}), 0);

But really, feel free to keep as is if you prefer.

SmallVector<SDValue, 2> Ops = {Chain};
if (InGlue)
Ops.push_back(InGlue);
return SDValue(DAG.getMachineNode(AArch64::CHECK_MATCHING_VL, DL,
DAG.getVTList(MVT::Other, MVT::Glue),
Ops),
0);
};

// NS -> S
if (Enable) {
SDValue CheckVL = GetCheckVL(Chain, InGlue);

// Replace chain
Ops[0] = CheckVL.getValue(0);

// Replace/append glue
if (InGlue)
Ops.back() = CheckVL.getValue(1);
else
Ops.push_back(CheckVL.getValue(1));

return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
}

// S -> NS
SDValue StreamingModeInstr =
DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
return GetCheckVL(StreamingModeInstr.getValue(0),
StreamingModeInstr.getValue(1));
}

// Emit a call to __arm_sme_save or __arm_sme_restore.
Expand Down Expand Up @@ -9732,7 +9815,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (RequiresSMChange) {
Chain =
changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
Chain, InGlue, getSMToggleCondition(CallAttrs));
Chain, InGlue, getSMToggleCondition(CallAttrs),
CallConv == CallingConv::AArch64_SVE_VectorCall);
InGlue = Chain.getValue(1);
}

Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ class AArch64TargetLowering : public TargetLowering {
MachineBasicBlock *EmitDynamicProbedAlloc(MachineInstr &MI,
MachineBasicBlock *MBB) const;

MachineBasicBlock *EmitCheckVL(MachineInstr &MI,
MachineBasicBlock *MBB) const;

MachineBasicBlock *EmitTileLoad(unsigned Opc, unsigned BaseReg,
MachineInstr &MI,
MachineBasicBlock *BB) const;
Expand Down Expand Up @@ -532,8 +535,8 @@ class AArch64TargetLowering : public TargetLowering {
/// node. \p Condition should be one of the enum values from
/// AArch64SME::ToggleCondition.
SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
SDValue Chain, SDValue InGlue,
unsigned Condition) const;
SDValue Chain, SDValue InGlue, unsigned Condition,
bool HasSVECC = false) const;

bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }

Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,12 @@ def PROBED_STACKALLOC_DYN : Pseudo<(outs),
} // Defs = [SP, NZCV], Uses = [SP] in
} // hasSideEffects = 1, isCodeGenOnly = 1

// Pseudo-instruction that compares the current SVE vector length (VL) with the
// streaming vector length (SVL). If the two lengths do not match, the check
// lowers to a `brk`, causing a trap.
let hasSideEffects = 1, isCodeGenOnly = 1, usesCustomInserter = 1 in
def CHECK_MATCHING_VL : Pseudo<(outs), (ins), []>, Sched<[]>;

let isReMaterializable = 1, isCodeGenOnly = 1 in {
// FIXME: The following pseudo instructions are only needed because remat
// cannot handle multiple instructions. When that changes, they can be
Expand Down
26 changes: 20 additions & 6 deletions llvm/test/CodeGen/AArch64/sme-callee-save-restore-pairs.ll
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,19 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
; NOPAIR-NEXT: // %bb.1:
; NOPAIR-NEXT: smstop sm
; NOPAIR-NEXT: .LBB0_2:
; NOPAIR-NEXT: rdvl x8, #1
; NOPAIR-NEXT: rdsvl x9, #1
; NOPAIR-NEXT: cmp x8, x9
; NOPAIR-NEXT: b.eq .LBB0_4
; NOPAIR-NEXT: // %bb.3:
; NOPAIR-NEXT: brk #0x1
; NOPAIR-NEXT: .LBB0_4:
; NOPAIR-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; NOPAIR-NEXT: bl my_func2
; NOPAIR-NEXT: tbz w19, #0, .LBB0_4
; NOPAIR-NEXT: // %bb.3:
; NOPAIR-NEXT: tbz w19, #0, .LBB0_6
; NOPAIR-NEXT: // %bb.5:
; NOPAIR-NEXT: smstart sm
; NOPAIR-NEXT: .LBB0_4:
; NOPAIR-NEXT: .LBB0_6:
; NOPAIR-NEXT: addvl sp, sp, #1
; NOPAIR-NEXT: ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload
; NOPAIR-NEXT: ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload
Expand Down Expand Up @@ -127,12 +134,19 @@ define void @fbyte(<vscale x 16 x i8> %v) #0{
; PAIR-NEXT: // %bb.1:
; PAIR-NEXT: smstop sm
; PAIR-NEXT: .LBB0_2:
; PAIR-NEXT: rdvl x8, #1
; PAIR-NEXT: rdsvl x9, #1
; PAIR-NEXT: cmp x8, x9
; PAIR-NEXT: b.eq .LBB0_4
; PAIR-NEXT: // %bb.3:
; PAIR-NEXT: brk #0x1
; PAIR-NEXT: .LBB0_4:
; PAIR-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; PAIR-NEXT: bl my_func2
; PAIR-NEXT: tbz w19, #0, .LBB0_4
; PAIR-NEXT: // %bb.3:
; PAIR-NEXT: tbz w19, #0, .LBB0_6
; PAIR-NEXT: // %bb.5:
; PAIR-NEXT: smstart sm
; PAIR-NEXT: .LBB0_4:
; PAIR-NEXT: .LBB0_6:
; PAIR-NEXT: addvl sp, sp, #1
; PAIR-NEXT: ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload
; PAIR-NEXT: ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload
Expand Down
14 changes: 13 additions & 1 deletion llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,26 @@ define void @test13(ptr %ptr) nounwind "aarch64_pstate_sm_enabled" {
; CHECK-NEXT: stp x30, x19, [sp, #80] // 16-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: mov z0.s, #0 // =0x0
; CHECK-NEXT: mov x19, x0
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill
; CHECK-NEXT: smstop sm
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: rdsvl x9, #1
; CHECK-NEXT: cmp x8, x9
; CHECK-NEXT: b.ne .LBB14_2
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; CHECK-NEXT: mov x19, x0
; CHECK-NEXT: bl callee_farg_fret
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill
; CHECK-NEXT: smstart sm
; CHECK-NEXT: smstop sm
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: rdsvl x9, #1
; CHECK-NEXT: cmp x8, x9
; CHECK-NEXT: b.eq .LBB14_3
; CHECK-NEXT: .LBB14_2:
; CHECK-NEXT: brk #0x1
; CHECK-NEXT: .LBB14_3:
; CHECK-NEXT: ldr z0, [sp] // 16-byte Folded Reload
; CHECK-NEXT: bl callee_farg_fret
; CHECK-NEXT: str z0, [sp] // 16-byte Folded Spill
Expand Down
Loading
Loading