Skip to content

Commit 6bf214b

Browse files
authored
[GlobalISel][AArch64] Legalize G_INSERT_VECTOR_ELT for SVE (llvm#114310)
There are patterns for: * {nxv2s32, s32, s64}, * {nxv4s16, s16, s64}, * {nxv2s16, s16, s64}
1 parent 948249d commit 6bf214b

File tree

5 files changed

+501
-8
lines changed

5 files changed

+501
-8
lines changed

llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,11 @@ inline LegalityPredicate typeIsNot(unsigned TypeIdx, LLT Type) {
273273
LegalityPredicate
274274
typePairInSet(unsigned TypeIdx0, unsigned TypeIdx1,
275275
std::initializer_list<std::pair<LLT, LLT>> TypesInit);
276+
/// True iff the given types for the given tuple of type indexes is one of the
277+
/// specified type tuples.
278+
LegalityPredicate
279+
typeTupleInSet(unsigned TypeIdx0, unsigned TypeIdx1, unsigned TypeIdx2,
280+
std::initializer_list<std::tuple<LLT, LLT, LLT>> TypesInit);
276281
/// True iff the given types for the given pair of type indexes is one of the
277282
/// specified type pairs.
278283
LegalityPredicate typePairAndMemDescInSet(
@@ -504,6 +509,15 @@ class LegalizeRuleSet {
504509
using namespace LegalityPredicates;
505510
return actionIf(Action, typePairInSet(typeIdx(0), typeIdx(1), Types));
506511
}
512+
513+
/// Use the given action when type indexes 0, 1 and 2 form any type tuple in
/// the given list.
LegalizeRuleSet &
514+
actionFor(LegalizeAction Action,
515+
std::initializer_list<std::tuple<LLT, LLT, LLT>> Types) {
516+
using namespace LegalityPredicates;
517+
// Delegate to actionIf with a predicate over type indexes 0, 1 and 2.
return actionIf(Action,
518+
typeTupleInSet(typeIdx(0), typeIdx(1), typeIdx(2), Types));
519+
}
520+
507521
/// Use the given action when type indexes 0 and 1 is any type pair in the
508522
/// given list.
509523
/// Action should be an action that requires mutation.
@@ -615,6 +629,12 @@ class LegalizeRuleSet {
615629
return *this;
616630
return actionFor(LegalizeAction::Legal, Types);
617631
}
632+
/// The instruction is legal when \p Pred is true and the types match any
/// tuple in the given list. When \p Pred is false, the rule set is returned
/// without calling actionFor.
LegalizeRuleSet &
633+
legalFor(bool Pred, std::initializer_list<std::tuple<LLT, LLT, LLT>> Types) {
634+
// Predicate not satisfied: return the rule set unchanged.
if (!Pred)
635+
return *this;
636+
return actionFor(LegalizeAction::Legal, Types);
637+
}
618638
/// The instruction is legal when type index 0 is any type in the given list
619639
/// and imm index 0 is anything.
620640
LegalizeRuleSet &legalForTypeWithAnyImm(std::initializer_list<LLT> Types) {

llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,17 @@ LegalityPredicate LegalityPredicates::typePairInSet(
4949
};
5050
}
5151

52+
// Returns a predicate that is true iff the query's types at TypeIdx0,
// TypeIdx1 and TypeIdx2, taken together as a tuple, equal one of the tuples
// in TypesInit.
LegalityPredicate LegalityPredicates::typeTupleInSet(
53+
unsigned TypeIdx0, unsigned TypeIdx1, unsigned TypeIdx2,
54+
std::initializer_list<std::tuple<LLT, LLT, LLT>> TypesInit) {
55+
// Copy the tuples into a SmallVector; the lambda below captures it by value.
SmallVector<std::tuple<LLT, LLT, LLT>, 4> Types = TypesInit;
56+
return [=](const LegalityQuery &Query) {
57+
// Assemble the queried types in the same order as the stored tuples.
std::tuple<LLT, LLT, LLT> Match = {
58+
Query.Types[TypeIdx0], Query.Types[TypeIdx1], Query.Types[TypeIdx2]};
59+
return llvm::is_contained(Types, Match);
60+
};
61+
}
62+
5263
LegalityPredicate LegalityPredicates::typePairAndMemDescInSet(
5364
unsigned TypeIdx0, unsigned TypeIdx1, unsigned MMOIdx,
5465
std::initializer_list<TypePairAndMemDesc> TypesAndMemDescInit) {

llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
978978
getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
979979
.legalIf(
980980
typeInSet(0, {v16s8, v8s8, v8s16, v4s16, v4s32, v2s32, v2s64, v2p0}))
981+
.legalFor(HasSVE, {{nxv16s8, s32, s64},
982+
{nxv8s16, s32, s64},
983+
{nxv4s32, s32, s64},
984+
{nxv2s64, s64, s64}})
981985
.moreElementsToNextPow2(0)
982986
.widenVectorEltsToVectorMinSize(0, 64)
983987
.clampNumElements(0, v8s8, v16s8)

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ bool matchREV(MachineInstr &MI, MachineRegisterInfo &MRI,
161161
Register Dst = MI.getOperand(0).getReg();
162162
Register Src = MI.getOperand(1).getReg();
163163
LLT Ty = MRI.getType(Dst);
164+
if (Ty.isScalableVector())
165+
return false;
164166
unsigned EltSize = Ty.getScalarSizeInBits();
165167

166168
// Element size for a rev cannot be 64.
@@ -196,7 +198,10 @@ bool matchTRN(MachineInstr &MI, MachineRegisterInfo &MRI,
196198
unsigned WhichResult;
197199
ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
198200
Register Dst = MI.getOperand(0).getReg();
199-
unsigned NumElts = MRI.getType(Dst).getNumElements();
201+
LLT DstTy = MRI.getType(Dst);
202+
if (DstTy.isScalableVector())
203+
return false;
204+
unsigned NumElts = DstTy.getNumElements();
200205
if (!isTRNMask(ShuffleMask, NumElts, WhichResult))
201206
return false;
202207
unsigned Opc = (WhichResult == 0) ? AArch64::G_TRN1 : AArch64::G_TRN2;
@@ -217,7 +222,10 @@ bool matchUZP(MachineInstr &MI, MachineRegisterInfo &MRI,
217222
unsigned WhichResult;
218223
ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
219224
Register Dst = MI.getOperand(0).getReg();
220-
unsigned NumElts = MRI.getType(Dst).getNumElements();
225+
LLT DstTy = MRI.getType(Dst);
226+
if (DstTy.isScalableVector())
227+
return false;
228+
unsigned NumElts = DstTy.getNumElements();
221229
if (!isUZPMask(ShuffleMask, NumElts, WhichResult))
222230
return false;
223231
unsigned Opc = (WhichResult == 0) ? AArch64::G_UZP1 : AArch64::G_UZP2;
@@ -233,7 +241,10 @@ bool matchZip(MachineInstr &MI, MachineRegisterInfo &MRI,
233241
unsigned WhichResult;
234242
ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
235243
Register Dst = MI.getOperand(0).getReg();
236-
unsigned NumElts = MRI.getType(Dst).getNumElements();
244+
LLT DstTy = MRI.getType(Dst);
245+
if (DstTy.isScalableVector())
246+
return false;
247+
unsigned NumElts = DstTy.getNumElements();
237248
if (!isZIPMask(ShuffleMask, NumElts, WhichResult))
238249
return false;
239250
unsigned Opc = (WhichResult == 0) ? AArch64::G_ZIP1 : AArch64::G_ZIP2;
@@ -288,7 +299,10 @@ bool matchDupFromBuildVector(int Lane, MachineInstr &MI,
288299
MachineRegisterInfo &MRI,
289300
ShuffleVectorPseudo &MatchInfo) {
290301
assert(Lane >= 0 && "Expected positive lane?");
291-
int NumElements = MRI.getType(MI.getOperand(1).getReg()).getNumElements();
302+
LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg());
303+
if (Op1Ty.isScalableVector())
304+
return false;
305+
int NumElements = Op1Ty.getNumElements();
292306
// Test if the LHS is a BUILD_VECTOR. If it is, then we can just reference the
293307
// lane's definition directly.
294308
auto *BuildVecMI =
@@ -326,6 +340,8 @@ bool matchDup(MachineInstr &MI, MachineRegisterInfo &MRI,
326340
// Check if an EXT instruction can handle the shuffle mask when the vector
327341
// sources of the shuffle are the same.
328342
bool isSingletonExtMask(ArrayRef<int> M, LLT Ty) {
343+
if (Ty.isScalableVector())
344+
return false;
329345
unsigned NumElts = Ty.getNumElements();
330346

331347
// Assume that the first shuffle index is not UNDEF. Fail if it is.
@@ -357,12 +373,17 @@ bool matchEXT(MachineInstr &MI, MachineRegisterInfo &MRI,
357373
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
358374
Register Dst = MI.getOperand(0).getReg();
359375
LLT DstTy = MRI.getType(Dst);
376+
if (DstTy.isScalableVector())
377+
return false;
360378
Register V1 = MI.getOperand(1).getReg();
361379
Register V2 = MI.getOperand(2).getReg();
362380
auto Mask = MI.getOperand(3).getShuffleMask();
363381
uint64_t Imm;
364382
auto ExtInfo = getExtMask(Mask, DstTy.getNumElements());
365-
uint64_t ExtFactor = MRI.getType(V1).getScalarSizeInBits() / 8;
383+
LLT V1Ty = MRI.getType(V1);
384+
if (V1Ty.isScalableVector())
385+
return false;
386+
uint64_t ExtFactor = V1Ty.getScalarSizeInBits() / 8;
366387

367388
if (!ExtInfo) {
368389
if (!getOpcodeDef<GImplicitDef>(V2, MRI) ||
@@ -423,6 +444,8 @@ void applyNonConstInsert(MachineInstr &MI, MachineRegisterInfo &MRI,
423444

424445
Register Offset = Insert.getIndexReg();
425446
LLT VecTy = MRI.getType(Insert.getReg(0));
447+
if (VecTy.isScalableVector())
448+
return;
426449
LLT EltTy = MRI.getType(Insert.getElementReg());
427450
LLT IdxTy = MRI.getType(Insert.getIndexReg());
428451

@@ -473,7 +496,10 @@ bool matchINS(MachineInstr &MI, MachineRegisterInfo &MRI,
473496
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
474497
ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
475498
Register Dst = MI.getOperand(0).getReg();
476-
int NumElts = MRI.getType(Dst).getNumElements();
499+
LLT DstTy = MRI.getType(Dst);
500+
if (DstTy.isScalableVector())
501+
return false;
502+
int NumElts = DstTy.getNumElements();
477503
auto DstIsLeftAndDstLane = isINSMask(ShuffleMask, NumElts);
478504
if (!DstIsLeftAndDstLane)
479505
return false;
@@ -522,6 +548,8 @@ bool isVShiftRImm(Register Reg, MachineRegisterInfo &MRI, LLT Ty,
522548
if (!Cst)
523549
return false;
524550
Cnt = *Cst;
551+
if (Ty.isScalableVector())
552+
return false;
525553
int64_t ElementBits = Ty.getScalarSizeInBits();
526554
return Cnt >= 1 && Cnt <= ElementBits;
527555
}
@@ -698,6 +726,8 @@ bool matchDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
698726
Register Src1Reg = MI.getOperand(1).getReg();
699727
const LLT SrcTy = MRI.getType(Src1Reg);
700728
const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
729+
if (SrcTy.isScalableVector())
730+
return false;
701731

702732
auto LaneIdx = getSplatIndex(MI);
703733
if (!LaneIdx)
@@ -774,6 +804,8 @@ bool matchScalarizeVectorUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI) {
774804
auto &Unmerge = cast<GUnmerge>(MI);
775805
Register Src1Reg = Unmerge.getReg(Unmerge.getNumOperands() - 1);
776806
const LLT SrcTy = MRI.getType(Src1Reg);
807+
if (SrcTy.isScalableVector())
808+
return false;
777809
if (SrcTy.getSizeInBits() != 128 && SrcTy.getSizeInBits() != 64)
778810
return false;
779811
return SrcTy.isVector() && !SrcTy.isScalable() &&
@@ -987,7 +1019,10 @@ bool matchLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
9871019
if (!DstTy.isVector() || !ST.hasNEON())
9881020
return false;
9891021
Register LHS = MI.getOperand(2).getReg();
990-
unsigned EltSize = MRI.getType(LHS).getScalarSizeInBits();
1022+
LLT LHSTy = MRI.getType(LHS);
1023+
if (LHSTy.isScalableVector())
1024+
return false;
1025+
unsigned EltSize = LHSTy.getScalarSizeInBits();
9911026
if (EltSize == 16 && !ST.hasFullFP16())
9921027
return false;
9931028
if (EltSize != 16 && EltSize != 32 && EltSize != 64)
@@ -1183,7 +1218,7 @@ bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI) {
11831218
MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
11841219
MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
11851220

1186-
if (DstTy.isVector()) {
1221+
if (DstTy.isFixedVector()) {
11871222
// If the source operands were EXTENDED before, then {U/S}MULL can be used
11881223
unsigned I1Opc = I1->getOpcode();
11891224
unsigned I2Opc = I2->getOpcode();

0 commit comments

Comments
 (0)