Merged
Changes from 3 commits
32 changes: 32 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -67,6 +67,10 @@ class AArch64ExpandPseudo : public MachineFunctionPass {
TargetRegisterClass ContiguousClass,
TargetRegisterClass StridedClass,
unsigned ContiguousOpc, unsigned StridedOpc);
bool expandFormTuplePseudo(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
MachineBasicBlock::iterator &NextMBBI,
unsigned Size);
bool expandMOVImm(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
unsigned BitSize);

@@ -1142,6 +1146,30 @@ bool AArch64ExpandPseudo::expandMultiVecPseudo(
return true;
}

bool AArch64ExpandPseudo::expandFormTuplePseudo(
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
MachineBasicBlock::iterator &NextMBBI, unsigned Size) {
assert((Size == 2 || Size == 4) && "Invalid Tuple Size");
MachineInstr &MI = *MBBI;
Register ReturnTuple = MI.getOperand(0).getReg();

const TargetRegisterInfo *TRI =
MBB.getParent()->getSubtarget().getRegisterInfo();
for (unsigned i = 0; i < Size; i++) {
Collaborator:

nit: suggested change — replace
for (unsigned i = 0; i < Size; i++) {
with
for (unsigned I = 0; I < Size; ++I) {

Register FormTupleOpReg = MI.getOperand(i + 1).getReg();
Register ReturnTupleSubReg =
TRI->getSubReg(ReturnTuple, AArch64::zsub0 + i);
if (FormTupleOpReg != ReturnTupleSubReg)
BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::ORR_ZZZ))
.addReg(ReturnTupleSubReg, RegState::Define)
.addReg(FormTupleOpReg)
.addReg(FormTupleOpReg);
}

MI.eraseFromParent();
return true;
}

/// If MBBI references a pseudo instruction that should be expanded here,
/// do the expansion and return true. Otherwise return false.
bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
@@ -1724,6 +1752,10 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
return expandMultiVecPseudo(
MBB, MBBI, AArch64::ZPR4RegClass, AArch64::ZPR4StridedRegClass,
AArch64::LDNT1D_4Z, AArch64::LDNT1D_4Z_STRIDED);
case AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO:
return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 2);
case AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO:
return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 4);
}
return false;
}
27 changes: 27 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -504,6 +504,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {

bool SelectAllActivePredicate(SDValue N);
bool SelectAnyPredicate(SDValue N);

void SelectFormTuplePseudo(SDNode *N, unsigned Size);
};

class AArch64DAGToDAGISelLegacy : public SelectionDAGISelLegacy {
@@ -7181,6 +7183,14 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
}
break;
}
case AArch64ISD::FORM_STRIDED_TUPLE_X2: {
SelectFormTuplePseudo(Node, 2);
return;
}
case AArch64ISD::FORM_STRIDED_TUPLE_X4: {
SelectFormTuplePseudo(Node, 4);
return;
}
}

// Select the default instruction
@@ -7438,3 +7448,20 @@ bool AArch64DAGToDAGISel::SelectSMETileSlice(SDValue N, unsigned MaxSize,
Offset = CurDAG->getTargetConstant(0, SDLoc(N), MVT::i64);
return true;
}

void AArch64DAGToDAGISel::SelectFormTuplePseudo(SDNode *Node, unsigned Size) {
assert((Size == 2 || Size == 4) && "Invalid Tuple size");
EVT VT = Node->getValueType(0);
SmallVector<SDValue> Ops;
for (unsigned I = 0; I < Size; I++)
Ops.push_back(Node->getOperand(I));
SDLoc DL(Node);
unsigned Opc = Size == 2 ? AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO
: AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO;
SDNode *Tuple = CurDAG->getMachineNode(Opc, DL, MVT::Untyped, Ops);
SDValue SuperReg = SDValue(Tuple, 0);
for (unsigned I = 0; I < Size; ++I)
ReplaceUses(SDValue(Node, I), CurDAG->getTargetExtractSubreg(
AArch64::zsub0 + I, DL, VT, SuperReg));
CurDAG->RemoveDeadNode(Node);
}
63 changes: 63 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2808,6 +2808,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::FMUL_PRED)
MAKE_CASE(AArch64ISD::FSUB_PRED)
MAKE_CASE(AArch64ISD::RDSVL)
MAKE_CASE(AArch64ISD::FORM_STRIDED_TUPLE_X2)
MAKE_CASE(AArch64ISD::FORM_STRIDED_TUPLE_X4)
MAKE_CASE(AArch64ISD::BIC)
MAKE_CASE(AArch64ISD::CBZ)
MAKE_CASE(AArch64ISD::CBNZ)
@@ -5709,6 +5711,46 @@ SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
Mask);
}

static unsigned getIntrinsicID(const SDNode *N);

SDValue TryLowerMultiVecSMEDotIntrinsic(SDValue Op, SelectionDAG &DAG,
unsigned Size) {
assert((Size == 2 || Size == 4) && "Invalid Tuple Size");
auto IsStridedLoad = [Size](SDValue Op) -> bool {
unsigned Intrinsic = getIntrinsicID(Op.getNode());
if (Size == 2)
return Intrinsic == Intrinsic::aarch64_sve_ld1_pn_x2;
else
return Intrinsic == Intrinsic::aarch64_sve_ld1_pn_x4;
};

SmallVector<SDValue> Ops;
unsigned LastLoadIdx = Size == 2 ? 5 : 7;
unsigned LoadResNo = Op.getOperand(3).getResNo();
for (unsigned I = 3; I < LastLoadIdx; I++) {
if (!IsStridedLoad(Op->getOperand(I)) ||
Op.getOperand(I).getResNo() != LoadResNo)
return SDValue();
Ops.push_back(Op->getOperand(I));
}

EVT VT = Op->getOperand(3).getValueType();
SDVTList VTList =
Size == 2 ? DAG.getVTList(VT, VT) : DAG.getVTList(VT, VT, VT, VT);
unsigned Opc = Size == 2 ? AArch64ISD::FORM_STRIDED_TUPLE_X2
: AArch64ISD::FORM_STRIDED_TUPLE_X4;
SDLoc DL(Op);
SDValue Pseudo = DAG.getNode(Opc, DL, VTList, Ops);

SmallVector<SDValue> DotOps = {Op.getOperand(0), Op->getOperand(1),
Op->getOperand(2)};
for (unsigned I = 0; I < Size; I++)
DotOps.push_back(Pseudo.getValue(I));
DotOps.push_back(Op->getOperand(DotOps.size()));
DotOps.push_back(Op->getOperand(DotOps.size()));
return DAG.getNode(Op->getOpcode(), DL, MVT::Other, DotOps);
}

// Lower an SME LDR/STR ZA intrinsic
// Case 1: If the vector number (vecnum) is an immediate in range, it gets
// folded into the instruction
@@ -5898,6 +5940,22 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
Op->getOperand(0), // Chain
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_suvdot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_usvdot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_udot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x4:
case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x4:
return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 4);
case Intrinsic::aarch64_sme_uvdot_lane_za32_vg1x2:
case Intrinsic::aarch64_sme_sdot_lane_za32_vg1x2:
case Intrinsic::aarch64_sme_svdot_lane_za32_vg1x2:
case Intrinsic::aarch64_sme_usdot_lane_za32_vg1x2:
case Intrinsic::aarch64_sme_sudot_lane_za32_vg1x2:
case Intrinsic::aarch64_sme_udot_lane_za32_vg1x2:
return TryLowerMultiVecSMEDotIntrinsic(Op, DAG, 2);
Collaborator (@sdesmalen-arm):

Perhaps a structurally simpler way to implement this (avoiding the need for custom isel) is to change the patterns to always use the FORM_STRIDED_TUPLE_X.._PSEUDO instruction instead of REG_SEQUENCE.

For the multi-vector load case that you're trying to improve, the inputs to the tuple are always COPY nodes of the form:

%9:zpr = COPY %7.zsub0:zpr2stridedorcontiguous

There are cases where the RegisterCoalescer can make better decisions when it sees regular COPY nodes rather than the FORM_STRIDED_TUPLE pseudos. We could instead give the FORM_STRIDED_TUPLE pseudos hasPostISelHook = 1 and, directly after isel, transform them into a REG_SEQUENCE node whenever any of the input values is not a COPY whose source register is in a 'stridedorcontiguous' register class. The REG_SEQUENCE node itself is then lowered later by the TwoAddressInstructionPass into individual COPY nodes.

Contributor Author:

Thank you for the suggestion, @sdesmalen-arm. I was able to remove the FORM_STRIDED_TUPLE nodes and instead add hasPostISelHook = 1 to the pseudos, creating a REG_SEQUENCE if the input values are not copies from a StridedOrContiguous source register. This has been added in a new commit, with the RegAllocHints commit added on top.
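
For illustration, here is a minimal sketch of what such a hasPostISelHook-based fallback might look like. This is not the code from the new commit referenced above; the helper names (isCopyFromStridedOrContiguous, lowerFormStridedTuplePostISel) are hypothetical, and it assumes the standard LLVM CodeGen APIs plus the AArch64 ZPR2/ZPR4 StridedOrContiguous register classes.

// Illustrative sketch only -- not the code added in the new commit.
// Helper names are hypothetical; intended to be reached from the target's
// AdjustInstrPostInstrSelection hook for the FORM_STRIDED_TUPLE pseudos.
#include "AArch64InstrInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"

using namespace llvm;

// True if Reg is defined by a COPY whose source is a virtual register in one
// of the strided-or-contiguous multi-vector register classes.
static bool isCopyFromStridedOrContiguous(const MachineRegisterInfo &MRI,
                                          Register Reg) {
  const MachineInstr *Def = MRI.getVRegDef(Reg);
  if (!Def || !Def->isCopy())
    return false;
  Register Src = Def->getOperand(1).getReg();
  if (!Src.isVirtual())
    return false;
  const TargetRegisterClass *RC = MRI.getRegClass(Src);
  return RC == &AArch64::ZPR2StridedOrContiguousRegClass ||
         RC == &AArch64::ZPR4StridedOrContiguousRegClass;
}

// If any input of the FORM_STRIDED_TUPLE pseudo is not such a COPY, rewrite
// the pseudo as a plain REG_SEQUENCE, which TwoAddressInstructionPass later
// expands into individual COPYs.
static void lowerFormStridedTuplePostISel(MachineInstr &MI,
                                          const TargetInstrInfo &TII) {
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();

  bool AllStridedCopies = true;
  for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I)
    if (!isCopyFromStridedOrContiguous(MRI, MI.getOperand(I).getReg()))
      AllStridedCopies = false;

  // Every input is a copy from a strided-or-contiguous source: keep the
  // pseudo so register allocation can still favour the strided form.
  if (AllStridedCopies)
    return;

  // Otherwise emit: %tup = REG_SEQUENCE %zn0, zsub0, %zn1, zsub1, ...
  MachineInstrBuilder MIB =
      BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
              TII.get(TargetOpcode::REG_SEQUENCE), MI.getOperand(0).getReg());
  for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I)
    MIB.addReg(MI.getOperand(I).getReg()).addImm(AArch64::zsub0 + (I - 1));
  MI.eraseFromParent();
}

The design point, as the reviewer notes, is that the REG_SEQUENCE fallback leaves ordinary COPYs for the RegisterCoalescer to reason about, while the pseudo is only kept when the inputs already come from a strided-or-contiguous tuple.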

}
}

@@ -7639,6 +7697,11 @@ static unsigned getIntrinsicID(const SDNode *N) {
return IID;
return Intrinsic::not_intrinsic;
}
case ISD::INTRINSIC_W_CHAIN: {
unsigned IID = N->getConstantOperandVal(1);
if (IID < Intrinsic::num_intrinsics)
return IID;
return Intrinsic::not_intrinsic;
}
}
}

3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -478,6 +478,9 @@ enum NodeType : unsigned {
SME_ZA_LDR,
SME_ZA_STR,

FORM_STRIDED_TUPLE_X2,
FORM_STRIDED_TUPLE_X4,

// NEON Load/Store with post-increment base updates
LD2post = ISD::FIRST_TARGET_MEMORY_OPCODE,
LD3post,
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -28,6 +28,17 @@ def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
[SDTCisInt<0>, SDTCisPtrTy<1>]>,
[SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;

Collaborator:
nit: remove newline

def SDT_FORM_STRIDED_TUPLE_X2 : SDTypeProfile<2, 2,
[SDTCisVec<0>, SDTCisSameAs<0, 1>,
SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>]>;

def SDT_FORM_STRIDED_TUPLE_X4 : SDTypeProfile<4, 4,
[SDTCisVec<0>, SDTCisSameAs<0, 1>,
SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>,
SDTCisSameAs<0, 4>, SDTCisSameAs<0, 5>,
SDTCisSameAs<0, 6>, SDTCisSameAs<0, 7>]>;

def AArch64CoalescerBarrier
: SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, [SDNPOptInGlue, SDNPOutGlue]>;

12 changes: 12 additions & 0 deletions llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -34,6 +34,18 @@ def tileslicerange0s4 : ComplexPattern<i32, 2, "SelectSMETileSlice<0, 4>", []>;

def am_sme_indexed_b4 :ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0,15>", [], [SDNPWantRoot]>;

def FORM_STRIDED_TUPLE_X2_PSEUDO :
Pseudo<(outs ZPR2Mul2:$tup),
(ins ZPR:$zn0, ZPR:$zn1), []>, Sched<[]> {
let hasSideEffects = 0;
}

def FORM_STRIDED_TUPLE_X4_PSEUDO :
Pseudo<(outs ZPR4Mul4:$tup),
(ins ZPR:$zn0, ZPR:$zn1, ZPR:$zn2, ZPR:$zn3), []>, Sched<[]> {
let hasSideEffects = 0;
}

def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>;
def AArch64SMELdr : SDNode<"AArch64ISD::SME_ZA_LDR", SDTZALoadStore,
[SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;