Skip to content

Commit d30a8d6

Browse files
committed
[AArch64][GlobalISel] Perfect Shuffles
This is a port of the existing perfect shuffle generation code from SDAG, geneticized to work for both SDAG and GISel. I wrote it a while ago and it has been sitting on my machine. It brings the codegen for certain shuffles inline and avoids the need for generating a tbl and constant pool load.
1 parent 56ffb72 commit d30a8d6

16 files changed

+466
-352
lines changed

llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,6 +1365,23 @@ class MachineIRBuilder {
13651365
const SrcOp &Elt,
13661366
const SrcOp &Idx);
13671367

1368+
/// Build and insert \p Res = G_INSERT_VECTOR_ELT \p Val, \p Elt, \p Idx
1369+
///
1370+
/// \pre setBasicBlock or setMI must have been called.
1371+
/// \pre \p Res must be a generic virtual register with scalar type.
1372+
/// \pre \p Val must be a generic virtual register with vector type.
1373+
/// \pre \p Elt must be a generic virtual register with scalar type.
1374+
///
1375+
/// \return The newly created instruction.
1376+
MachineInstrBuilder buildInsertVectorElementConstant(const DstOp &Res,
1377+
const SrcOp &Val,
1378+
const SrcOp &Elt,
1379+
const int Idx) {
1380+
const TargetLowering *TLI = getMF().getSubtarget().getTargetLowering();
1381+
LLT IdxTy = TLI->getVectorIdxLLT(getDataLayout());
1382+
return buildInsertVectorElement(Res, Val, Elt, buildConstant(IdxTy, Idx));
1383+
}
1384+
13681385
/// Build and insert \p Res = G_EXTRACT_VECTOR_ELT \p Val, \p Idx
13691386
///
13701387
/// \pre setBasicBlock or setMI must have been called.

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,8 @@ class TargetLoweringBase {
417417
}
418418

419419
/// Returns the type to be used for the index operand vector operations. By
420-
/// default we assume it will have the same size as an address space 0 pointer.
420+
/// default we assume it will have the same size as an address space 0
421+
/// pointer.
421422
virtual unsigned getVectorIdxWidth(const DataLayout &DL) const {
422423
return DL.getPointerSizeInBits(0);
423424
}

llvm/lib/Target/AArch64/AArch64Combine.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,13 @@ def shuf_to_ins: GICombineRule <
155155
(apply [{ applyINS(*${root}, MRI, B, ${matchinfo}); }])
156156
>;
157157

158+
def perfect_shuffle: GICombineRule <
159+
(defs root:$root),
160+
(match (G_SHUFFLE_VECTOR $dst, $src1, $src2, $mask):$root,
161+
[{ return matchPerfectShuffle(*${root}, MRI); }]),
162+
(apply [{ applyPerfectShuffle(*${root}, MRI, B); }])
163+
>;
164+
158165
def vashr_vlshr_imm_matchdata : GIDefMatchData<"int64_t">;
159166
def vashr_vlshr_imm : GICombineRule<
160167
(defs root:$root, vashr_vlshr_imm_matchdata:$matchinfo),
@@ -173,7 +180,8 @@ def form_duplane : GICombineRule <
173180
>;
174181

175182
def shuffle_vector_lowering : GICombineGroup<[dup, rev, ext, zip, uzp, trn, fullrev,
176-
form_duplane, shuf_to_ins]>;
183+
form_duplane, shuf_to_ins,
184+
perfect_shuffle]>;
177185

178186
// Turn G_UNMERGE_VALUES -> G_EXTRACT_VECTOR_ELT's
179187
def vector_unmerge_lowering : GICombineRule <

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 89 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -13467,172 +13467,6 @@ static SDValue tryFormConcatFromShuffle(SDValue Op, SelectionDAG &DAG) {
1346713467
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, V0, V1);
1346813468
}
1346913469

13470-
/// GeneratePerfectShuffle - Given an entry in the perfect-shuffle table, emit
13471-
/// the specified operations to build the shuffle. ID is the perfect-shuffle
13472-
//ID, V1 and V2 are the original shuffle inputs. PFEntry is the Perfect shuffle
13473-
//table entry and LHS/RHS are the immediate inputs for this stage of the
13474-
//shuffle.
13475-
static SDValue GeneratePerfectShuffle(unsigned ID, SDValue V1,
13476-
SDValue V2, unsigned PFEntry, SDValue LHS,
13477-
SDValue RHS, SelectionDAG &DAG,
13478-
const SDLoc &dl) {
13479-
unsigned OpNum = (PFEntry >> 26) & 0x0F;
13480-
unsigned LHSID = (PFEntry >> 13) & ((1 << 13) - 1);
13481-
unsigned RHSID = (PFEntry >> 0) & ((1 << 13) - 1);
13482-
13483-
enum {
13484-
OP_COPY = 0, // Copy, used for things like <u,u,u,3> to say it is <0,1,2,3>
13485-
OP_VREV,
13486-
OP_VDUP0,
13487-
OP_VDUP1,
13488-
OP_VDUP2,
13489-
OP_VDUP3,
13490-
OP_VEXT1,
13491-
OP_VEXT2,
13492-
OP_VEXT3,
13493-
OP_VUZPL, // VUZP, left result
13494-
OP_VUZPR, // VUZP, right result
13495-
OP_VZIPL, // VZIP, left result
13496-
OP_VZIPR, // VZIP, right result
13497-
OP_VTRNL, // VTRN, left result
13498-
OP_VTRNR, // VTRN, right result
13499-
OP_MOVLANE // Move lane. RHSID is the lane to move into
13500-
};
13501-
13502-
if (OpNum == OP_COPY) {
13503-
if (LHSID == (1 * 9 + 2) * 9 + 3)
13504-
return LHS;
13505-
assert(LHSID == ((4 * 9 + 5) * 9 + 6) * 9 + 7 && "Illegal OP_COPY!");
13506-
return RHS;
13507-
}
13508-
13509-
if (OpNum == OP_MOVLANE) {
13510-
// Decompose a PerfectShuffle ID to get the Mask for lane Elt
13511-
auto getPFIDLane = [](unsigned ID, int Elt) -> int {
13512-
assert(Elt < 4 && "Expected Perfect Lanes to be less than 4");
13513-
Elt = 3 - Elt;
13514-
while (Elt > 0) {
13515-
ID /= 9;
13516-
Elt--;
13517-
}
13518-
return (ID % 9 == 8) ? -1 : ID % 9;
13519-
};
13520-
13521-
// For OP_MOVLANE shuffles, the RHSID represents the lane to move into. We
13522-
// get the lane to move from the PFID, which is always from the
13523-
// original vectors (V1 or V2).
13524-
SDValue OpLHS = GeneratePerfectShuffle(
13525-
LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS, RHS, DAG, dl);
13526-
EVT VT = OpLHS.getValueType();
13527-
assert(RHSID < 8 && "Expected a lane index for RHSID!");
13528-
unsigned ExtLane = 0;
13529-
SDValue Input;
13530-
13531-
// OP_MOVLANE are either D movs (if bit 0x4 is set) or S movs. D movs
13532-
// convert into a higher type.
13533-
if (RHSID & 0x4) {
13534-
int MaskElt = getPFIDLane(ID, (RHSID & 0x01) << 1) >> 1;
13535-
if (MaskElt == -1)
13536-
MaskElt = (getPFIDLane(ID, ((RHSID & 0x01) << 1) + 1) - 1) >> 1;
13537-
assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
13538-
ExtLane = MaskElt < 2 ? MaskElt : (MaskElt - 2);
13539-
Input = MaskElt < 2 ? V1 : V2;
13540-
if (VT.getScalarSizeInBits() == 16) {
13541-
Input = DAG.getBitcast(MVT::v2f32, Input);
13542-
OpLHS = DAG.getBitcast(MVT::v2f32, OpLHS);
13543-
} else {
13544-
assert(VT.getScalarSizeInBits() == 32 &&
13545-
"Expected 16 or 32 bit shuffle elemements");
13546-
Input = DAG.getBitcast(MVT::v2f64, Input);
13547-
OpLHS = DAG.getBitcast(MVT::v2f64, OpLHS);
13548-
}
13549-
} else {
13550-
int MaskElt = getPFIDLane(ID, RHSID);
13551-
assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
13552-
ExtLane = MaskElt < 4 ? MaskElt : (MaskElt - 4);
13553-
Input = MaskElt < 4 ? V1 : V2;
13554-
// Be careful about creating illegal types. Use f16 instead of i16.
13555-
if (VT == MVT::v4i16) {
13556-
Input = DAG.getBitcast(MVT::v4f16, Input);
13557-
OpLHS = DAG.getBitcast(MVT::v4f16, OpLHS);
13558-
}
13559-
}
13560-
SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
13561-
Input.getValueType().getVectorElementType(),
13562-
Input, DAG.getVectorIdxConstant(ExtLane, dl));
13563-
SDValue Ins =
13564-
DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, Input.getValueType(), OpLHS,
13565-
Ext, DAG.getVectorIdxConstant(RHSID & 0x3, dl));
13566-
return DAG.getBitcast(VT, Ins);
13567-
}
13568-
13569-
SDValue OpLHS, OpRHS;
13570-
OpLHS = GeneratePerfectShuffle(LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS,
13571-
RHS, DAG, dl);
13572-
OpRHS = GeneratePerfectShuffle(RHSID, V1, V2, PerfectShuffleTable[RHSID], LHS,
13573-
RHS, DAG, dl);
13574-
EVT VT = OpLHS.getValueType();
13575-
13576-
switch (OpNum) {
13577-
default:
13578-
llvm_unreachable("Unknown shuffle opcode!");
13579-
case OP_VREV:
13580-
// VREV divides the vector in half and swaps within the half.
13581-
if (VT.getVectorElementType() == MVT::i32 ||
13582-
VT.getVectorElementType() == MVT::f32)
13583-
return DAG.getNode(AArch64ISD::REV64, dl, VT, OpLHS);
13584-
// vrev <4 x i16> -> REV32
13585-
if (VT.getVectorElementType() == MVT::i16 ||
13586-
VT.getVectorElementType() == MVT::f16 ||
13587-
VT.getVectorElementType() == MVT::bf16)
13588-
return DAG.getNode(AArch64ISD::REV32, dl, VT, OpLHS);
13589-
// vrev <4 x i8> -> REV16
13590-
assert(VT.getVectorElementType() == MVT::i8);
13591-
return DAG.getNode(AArch64ISD::REV16, dl, VT, OpLHS);
13592-
case OP_VDUP0:
13593-
case OP_VDUP1:
13594-
case OP_VDUP2:
13595-
case OP_VDUP3: {
13596-
EVT EltTy = VT.getVectorElementType();
13597-
unsigned Opcode;
13598-
if (EltTy == MVT::i8)
13599-
Opcode = AArch64ISD::DUPLANE8;
13600-
else if (EltTy == MVT::i16 || EltTy == MVT::f16 || EltTy == MVT::bf16)
13601-
Opcode = AArch64ISD::DUPLANE16;
13602-
else if (EltTy == MVT::i32 || EltTy == MVT::f32)
13603-
Opcode = AArch64ISD::DUPLANE32;
13604-
else if (EltTy == MVT::i64 || EltTy == MVT::f64)
13605-
Opcode = AArch64ISD::DUPLANE64;
13606-
else
13607-
llvm_unreachable("Invalid vector element type?");
13608-
13609-
if (VT.getSizeInBits() == 64)
13610-
OpLHS = WidenVector(OpLHS, DAG);
13611-
SDValue Lane = DAG.getConstant(OpNum - OP_VDUP0, dl, MVT::i64);
13612-
return DAG.getNode(Opcode, dl, VT, OpLHS, Lane);
13613-
}
13614-
case OP_VEXT1:
13615-
case OP_VEXT2:
13616-
case OP_VEXT3: {
13617-
unsigned Imm = (OpNum - OP_VEXT1 + 1) * getExtFactor(OpLHS);
13618-
return DAG.getNode(AArch64ISD::EXT, dl, VT, OpLHS, OpRHS,
13619-
DAG.getConstant(Imm, dl, MVT::i32));
13620-
}
13621-
case OP_VUZPL:
13622-
return DAG.getNode(AArch64ISD::UZP1, dl, VT, OpLHS, OpRHS);
13623-
case OP_VUZPR:
13624-
return DAG.getNode(AArch64ISD::UZP2, dl, VT, OpLHS, OpRHS);
13625-
case OP_VZIPL:
13626-
return DAG.getNode(AArch64ISD::ZIP1, dl, VT, OpLHS, OpRHS);
13627-
case OP_VZIPR:
13628-
return DAG.getNode(AArch64ISD::ZIP2, dl, VT, OpLHS, OpRHS);
13629-
case OP_VTRNL:
13630-
return DAG.getNode(AArch64ISD::TRN1, dl, VT, OpLHS, OpRHS);
13631-
case OP_VTRNR:
13632-
return DAG.getNode(AArch64ISD::TRN2, dl, VT, OpLHS, OpRHS);
13633-
}
13634-
}
13635-
1363613470
static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask,
1363713471
SelectionDAG &DAG) {
1363813472
// Check to see if we can use the TBL instruction.
@@ -14056,8 +13890,95 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
1405613890
unsigned PFTableIndex = PFIndexes[0] * 9 * 9 * 9 + PFIndexes[1] * 9 * 9 +
1405713891
PFIndexes[2] * 9 + PFIndexes[3];
1405813892
unsigned PFEntry = PerfectShuffleTable[PFTableIndex];
14059-
return GeneratePerfectShuffle(PFTableIndex, V1, V2, PFEntry, V1, V2, DAG,
14060-
dl);
13893+
13894+
auto BuildRev = [&DAG, &dl](SDValue OpLHS) {
13895+
EVT VT = OpLHS.getValueType();
13896+
unsigned Opcode = VT.getScalarSizeInBits() == 32 ? AArch64ISD::REV64
13897+
: VT.getScalarSizeInBits() == 16 ? AArch64ISD::REV32
13898+
: AArch64ISD::REV16;
13899+
return DAG.getNode(Opcode, dl, VT, OpLHS);
13900+
};
13901+
auto BuildDup = [&DAG, &dl](SDValue OpLHS, unsigned Lane) {
13902+
EVT VT = OpLHS.getValueType();
13903+
unsigned Opcode;
13904+
if (VT.getScalarSizeInBits() == 8)
13905+
Opcode = AArch64ISD::DUPLANE8;
13906+
else if (VT.getScalarSizeInBits() == 16)
13907+
Opcode = AArch64ISD::DUPLANE16;
13908+
else if (VT.getScalarSizeInBits() == 32)
13909+
Opcode = AArch64ISD::DUPLANE32;
13910+
else if (VT.getScalarSizeInBits() == 64)
13911+
Opcode = AArch64ISD::DUPLANE64;
13912+
else
13913+
llvm_unreachable("Invalid vector element type?");
13914+
13915+
if (VT.getSizeInBits() == 64)
13916+
OpLHS = WidenVector(OpLHS, DAG);
13917+
return DAG.getNode(Opcode, dl, VT, OpLHS,
13918+
DAG.getConstant(Lane, dl, MVT::i64));
13919+
};
13920+
auto BuildExt = [&DAG, &dl](SDValue OpLHS, SDValue OpRHS, unsigned Imm) {
13921+
EVT VT = OpLHS.getValueType();
13922+
Imm = Imm * getExtFactor(OpLHS);
13923+
return DAG.getNode(AArch64ISD::EXT, dl, VT, OpLHS, OpRHS,
13924+
DAG.getConstant(Imm, dl, MVT::i32));
13925+
};
13926+
auto BuildZipLike = [&DAG, &dl](unsigned OpNum, SDValue OpLHS,
13927+
SDValue OpRHS) {
13928+
EVT VT = OpLHS.getValueType();
13929+
switch (OpNum) {
13930+
default:
13931+
llvm_unreachable("Unexpected perfect shuffle opcode\n");
13932+
case OP_VUZPL:
13933+
return DAG.getNode(AArch64ISD::UZP1, dl, VT, OpLHS, OpRHS);
13934+
case OP_VUZPR:
13935+
return DAG.getNode(AArch64ISD::UZP2, dl, VT, OpLHS, OpRHS);
13936+
case OP_VZIPL:
13937+
return DAG.getNode(AArch64ISD::ZIP1, dl, VT, OpLHS, OpRHS);
13938+
case OP_VZIPR:
13939+
return DAG.getNode(AArch64ISD::ZIP2, dl, VT, OpLHS, OpRHS);
13940+
case OP_VTRNL:
13941+
return DAG.getNode(AArch64ISD::TRN1, dl, VT, OpLHS, OpRHS);
13942+
case OP_VTRNR:
13943+
return DAG.getNode(AArch64ISD::TRN2, dl, VT, OpLHS, OpRHS);
13944+
}
13945+
};
13946+
auto BuildExtractInsert64 = [&DAG, &dl](SDValue ExtSrc, unsigned ExtLane,
13947+
SDValue InsSrc, unsigned InsLane) {
13948+
EVT VT = InsSrc.getValueType();
13949+
if (VT.getScalarSizeInBits() == 16) {
13950+
ExtSrc = DAG.getBitcast(MVT::v2f32, ExtSrc);
13951+
InsSrc = DAG.getBitcast(MVT::v2f32, InsSrc);
13952+
} else if (VT.getScalarSizeInBits() == 32) {
13953+
ExtSrc = DAG.getBitcast(MVT::v2f64, ExtSrc);
13954+
InsSrc = DAG.getBitcast(MVT::v2f64, InsSrc);
13955+
}
13956+
SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
13957+
ExtSrc.getValueType().getVectorElementType(),
13958+
ExtSrc, DAG.getVectorIdxConstant(ExtLane, dl));
13959+
SDValue Ins =
13960+
DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtSrc.getValueType(), InsSrc,
13961+
Ext, DAG.getVectorIdxConstant(InsLane, dl));
13962+
return DAG.getBitcast(VT, Ins);
13963+
};
13964+
auto BuildExtractInsert32 = [&DAG, &dl](SDValue ExtSrc, unsigned ExtLane,
13965+
SDValue InsSrc, unsigned InsLane) {
13966+
EVT VT = InsSrc.getValueType();
13967+
if (VT.getScalarSizeInBits() == 16) {
13968+
ExtSrc = DAG.getBitcast(MVT::v4f16, ExtSrc);
13969+
InsSrc = DAG.getBitcast(MVT::v4f16, InsSrc);
13970+
}
13971+
SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
13972+
ExtSrc.getValueType().getVectorElementType(),
13973+
ExtSrc, DAG.getVectorIdxConstant(ExtLane, dl));
13974+
SDValue Ins =
13975+
DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtSrc.getValueType(), InsSrc,
13976+
Ext, DAG.getVectorIdxConstant(InsLane, dl));
13977+
return DAG.getBitcast(VT, Ins);
13978+
};
13979+
return generatePerfectShuffle<SDValue, MVT>(
13980+
PFTableIndex, V1, V2, PFEntry, V1, V2, BuildExtractInsert64,
13981+
BuildExtractInsert32, BuildRev, BuildDup, BuildExt, BuildZipLike);
1406113982
}
1406213983

1406313984
// Check for a "select shuffle", generating a BSL to pick between lanes in

0 commit comments

Comments
 (0)