Skip to content

Commit df516f9

Browse files
[GISel][CombinerHelper] Use a stream to check for G_CONCAT_VECTOR
We check for iterative shift masks which corresponds to the CONCAT_VECTOR instruction.
1 parent a3ae452 commit df516f9

File tree

3 files changed

+491
-42
lines changed

3 files changed

+491
-42
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,6 @@ class CombinerHelper {
246246
/// or an implicit_def if \p Ops is empty.
247247
void applyCombineShuffleConcat(MachineInstr &MI, SmallVector<Register> &Ops);
248248

249-
/// Try to combine G_SHUFFLE_VECTOR into G_CONCAT_VECTORS.
250249
/// A function type that returns either the next value in a
251250
/// shufflemask or an empty value. Each iteration should return
252251
/// one value, like a Python iterator or a Lisp stream.
@@ -257,14 +256,13 @@ class CombinerHelper {
257256
///
258257
/// \pre MI.getOpcode() == G_SHUFFLE_VECTOR.
259258
bool tryCombineShuffleVector(MachineInstr &MI);
260-
/// Check if the G_SHUFFLE_VECTOR \p MI can be replaced by a
261-
/// concat_vectors.
262-
/// \p Ops will contain the operands needed to produce the flattened
263-
/// concat_vectors.
259+
/// Check if the G_SHUFFLE_VECTOR \p MI can be replaced by checking
260+
/// whether the shufflemask given matches that of a given generator.
264261
///
265262
/// \pre MI.getOpcode() == G_SHUFFLE_VECTOR.
266-
bool matchCombineShuffleVector(MachineInstr &MI,
267-
SmallVectorImpl<Register> &Ops);
263+
bool matchCombineShuffleVector(MachineInstr &MI, GeneratorType Generator,
264+
const size_t TargetDstSize);
265+
268266
/// Replace \p MI with a concat_vectors with \p Ops.
269267
void applyCombineShuffleVector(MachineInstr &MI,
270268
const ArrayRef<Register> Ops);

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -400,16 +400,64 @@ adderGenerator(const int32_t From, const int32_t To, const int32_t StepSize) {
400400
}
401401

402402
bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) {
403+
const Register DstReg = MI.getOperand(0).getReg();
404+
const LLT DstTy = MRI.getType(DstReg);
405+
const LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
406+
const unsigned DstNumElts = DstTy.isVector() ? DstTy.getNumElements() : 1;
407+
const unsigned SrcNumElts = SrcTy.isVector() ? SrcTy.getNumElements() : 1;
408+
409+
// {1, 2, ..., n} -> G_CONCAT_VECTOR
410+
// Turns a shuffle vector that only increments into a concat vector
411+
// instruction
412+
GeneratorType CountUp = adderGenerator(0, DstNumElts - 1, 1);
403413
SmallVector<Register, 4> Ops;
404-
if (matchCombineShuffleVector(MI, Ops)) {
414+
415+
if (matchCombineShuffleVector(MI, CountUp, 2 * SrcNumElts)) {
416+
// The shuffle is concatenating multiple vectors together.
417+
// Collect the different operands for that.
418+
Register UndefReg;
419+
const Register Src1 = MI.getOperand(1).getReg();
420+
const Register Src2 = MI.getOperand(2).getReg();
421+
422+
const ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
423+
424+
// The destination can be longer than the source, so we separate them into
425+
// equal blocks and check them separately to see if one of the blocks can be
426+
// copied whole.
427+
unsigned NumConcat = DstNumElts / SrcNumElts;
428+
unsigned Index = 0;
429+
for (unsigned Concat = 0; Concat < NumConcat; Concat++) {
430+
unsigned Target = (Concat + 1) * SrcNumElts;
431+
while (Index < Target) {
432+
int MaskElt = Mask[Index];
433+
if (MaskElt >= 0) {
434+
Ops.push_back((MaskElt < (int)SrcNumElts) ? Src1 : Src2);
435+
break;
436+
}
437+
Index++;
438+
}
439+
440+
if (Index == Target) {
441+
if (!UndefReg) {
442+
Builder.setInsertPt(*MI.getParent(), MI);
443+
UndefReg = Builder.buildUndef(SrcTy).getReg(0);
444+
}
445+
Ops.push_back(UndefReg);
446+
}
447+
448+
Index = Target;
449+
}
450+
405451
applyCombineShuffleVector(MI, Ops);
406452
return true;
407453
}
454+
408455
return false;
409456
}
410457

411458
bool CombinerHelper::matchCombineShuffleVector(MachineInstr &MI,
412-
SmallVectorImpl<Register> &Ops) {
459+
GeneratorType Generator,
460+
const size_t TargetDstSize) {
413461
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
414462
"Invalid instruction kind");
415463
LLT DstType = MRI.getType(MI.getOperand(0).getReg());
@@ -436,51 +484,24 @@ bool CombinerHelper::matchCombineShuffleVector(MachineInstr &MI,
436484
//
437485
// TODO: If the size between the source and destination don't match
438486
// we could still emit an extract vector element in that case.
439-
if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1)
487+
if ((DstNumElts < TargetDstSize) && DstNumElts != 1)
440488
return false;
441489

442-
// Check that the shuffle mask can be broken evenly between the
443-
// different sources.
444-
if (DstNumElts % SrcNumElts != 0)
445-
return false;
446-
447-
// Mask length is a multiple of the source vector length.
448-
// Check if the shuffle is some kind of concatenation of the input
449-
// vectors.
450-
unsigned NumConcat = DstNumElts / SrcNumElts;
451-
SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
452490
ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
453491
for (unsigned i = 0; i != DstNumElts; ++i) {
454492
int Idx = Mask[i];
493+
const int32_t ShiftIndex = Generator().value_or(-1);
494+
455495
// Undef value.
456-
if (Idx < 0)
496+
if (Idx < 0 || ShiftIndex < 0)
457497
continue;
498+
458499
// Ensure the indices in each SrcType sized piece are sequential and that
459500
// the same source is used for the whole piece.
460-
if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
461-
(ConcatSrcs[i / SrcNumElts] >= 0 &&
462-
ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
501+
if ((Idx % SrcNumElts != (ShiftIndex % SrcNumElts)))
463502
return false;
464-
// Remember which source this index came from.
465-
ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
466503
}
467504

468-
// The shuffle is concatenating multiple vectors together.
469-
// Collect the different operands for that.
470-
Register UndefReg;
471-
Register Src2 = MI.getOperand(2).getReg();
472-
for (auto Src : ConcatSrcs) {
473-
if (Src < 0) {
474-
if (!UndefReg) {
475-
Builder.setInsertPt(*MI.getParent(), MI);
476-
UndefReg = Builder.buildUndef(SrcType).getReg(0);
477-
}
478-
Ops.push_back(UndefReg);
479-
} else if (Src == 0)
480-
Ops.push_back(Src1);
481-
else
482-
Ops.push_back(Src2);
483-
}
484505
return true;
485506
}
486507

0 commit comments

Comments
 (0)