Skip to content

Commit a38842b

Browse files
[GISel][CombinerHelper] Use a stream to check mask patterns to detect CONCAT_VECTOR
We check for iteratively shifted masks, which correspond to the CONCAT_VECTOR instruction.
1 parent 52f054f commit a38842b

File tree

3 files changed

+337
-38
lines changed

3 files changed

+337
-38
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "llvm/CodeGen/Register.h"
2828
#include "llvm/IR/InstrTypes.h"
2929
#include <functional>
30+
#include <optional>
3031

3132
namespace llvm {
3233

@@ -255,8 +256,11 @@ class CombinerHelper {
255256
/// concat_vectors.
256257
///
257258
/// \pre MI.getOpcode() == G_SHUFFLE_VECTOR.
258-
bool matchCombineShuffleVector(MachineInstr &MI,
259-
SmallVectorImpl<Register> &Ops);
259+
using GeneratorType = std::function<std::optional<int32_t>()>;
260+
261+
bool matchCombineShuffleVector(MachineInstr &MI, GeneratorType Generator,
262+
const size_t TargetDstSize);
263+
260264
/// Replace \p MI with a concat_vectors with \p Ops.
261265
void applyCombineShuffleVector(MachineInstr &MI,
262266
const ArrayRef<Register> Ops);

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 71 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
#include "llvm/Support/MathExtras.h"
4040
#include "llvm/Target/TargetMachine.h"
4141
#include <cmath>
42+
#include <cstdint>
43+
#include <functional>
4244
#include <optional>
4345
#include <tuple>
4446

@@ -301,17 +303,77 @@ void CombinerHelper::applyCombineConcatVectors(
301303
replaceRegWith(MRI, DstReg, NewDstReg);
302304
}
303305

306+
// Create a stream from 0 to n with a specified number of steps
307+
CombinerHelper::GeneratorType
308+
adderGenerator(const int32_t From, const int32_t To, const int32_t StepSize) {
309+
int32_t Counter = From;
310+
return [Counter, To, StepSize]() mutable {
311+
std::optional<int32_t> OldCount = std::optional<int32_t>(Counter);
312+
Counter += StepSize;
313+
if (OldCount == (To + StepSize))
314+
OldCount = {};
315+
return OldCount;
316+
};
317+
}
318+
304319
/// Try to replace the G_SHUFFLE_VECTOR \p MI with a concatenation of its
/// source vectors (and undef vectors for pieces that are entirely undef).
///
/// The combine only fires when the destination splits evenly into
/// source-sized pieces, the mask passes matchCombineShuffleVector for the
/// ascending index stream 0, 1, ..., DstNumElts - 1, and every piece reads
/// from a single source.
///
/// \returns true if \p MI was rewritten via applyCombineShuffleVector, false
/// if nothing was changed.
bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) {
  const Register DstReg = MI.getOperand(0).getReg();
  const LLT DstTy = MRI.getType(DstReg);
  const LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
  const unsigned DstNumElts = DstTy.isVector() ? DstTy.getNumElements() : 1;
  const unsigned SrcNumElts = SrcTy.isVector() ? SrcTy.getNumElements() : 1;

  // The destination has to be a whole number of source-sized pieces.
  // Otherwise DstNumElts / SrcNumElts below truncates (to 0 for a scalar
  // destination with a vector source) and the concat would be built with too
  // few operands.
  if (DstNumElts % SrcNumElts != 0)
    return false;

  // {0, 1, ..., n - 1} -> G_CONCAT_VECTORS
  // Turns a shuffle vector whose indices only increment into a concat vector
  // instruction.
  GeneratorType CountUp = adderGenerator(0, DstNumElts - 1, 1);
  if (!matchCombineShuffleVector(MI, CountUp, 2 * SrcNumElts))
    return false;

  const Register Src1 = MI.getOperand(1).getReg();
  const Register Src2 = MI.getOperand(2).getReg();
  const ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();

  // Work out which source feeds each piece: -1 means every element of the
  // piece is undef, 0 means Src1, anything else means Src2. The match above
  // only constrains the position of an element inside its piece, so make
  // sure here that all defined elements of a piece agree on the source.
  // This runs before any instruction is built so that bailing out leaves the
  // function untouched.
  const unsigned NumConcat = DstNumElts / SrcNumElts;
  SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
  for (unsigned Index = 0; Index != DstNumElts; ++Index) {
    const int MaskElt = Mask[Index];
    if (MaskElt < 0)
      continue;

    const int Src = (MaskElt < (int)SrcNumElts) ? 0 : 1;
    int &PieceSrc = ConcatSrcs[Index / SrcNumElts];
    if (PieceSrc >= 0 && PieceSrc != Src)
      return false;
    PieceSrc = Src;
  }

  // The shuffle is concatenating multiple vectors together.
  // Collect the different operands for that.
  SmallVector<Register, 4> Ops;
  Register UndefReg;
  for (const int Src : ConcatSrcs) {
    if (Src < 0) {
      // A single undef vector is created lazily and shared by all undef
      // pieces.
      if (!UndefReg) {
        Builder.setInsertPt(*MI.getParent(), MI);
        UndefReg = Builder.buildUndef(SrcTy).getReg(0);
      }
      Ops.push_back(UndefReg);
    } else {
      Ops.push_back(Src == 0 ? Src1 : Src2);
    }
  }

  applyCombineShuffleVector(MI, Ops);
  return true;
}
312373

313374
bool CombinerHelper::matchCombineShuffleVector(MachineInstr &MI,
314-
SmallVectorImpl<Register> &Ops) {
375+
GeneratorType Generator,
376+
const size_t TargetDstSize) {
315377
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
316378
"Invalid instruction kind");
317379
LLT DstType = MRI.getType(MI.getOperand(0).getReg());
@@ -338,51 +400,24 @@ bool CombinerHelper::matchCombineShuffleVector(MachineInstr &MI,
338400
//
339401
// TODO: If the size between the source and destination don't match
340402
// we could still emit an extract vector element in that case.
341-
if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1)
342-
return false;
343-
344-
// Check that the shuffle mask can be broken evenly between the
345-
// different sources.
346-
if (DstNumElts % SrcNumElts != 0)
403+
if ((DstNumElts < TargetDstSize) && DstNumElts != 1)
347404
return false;
348405

349-
// Mask length is a multiple of the source vector length.
350-
// Check if the shuffle is some kind of concatenation of the input
351-
// vectors.
352-
unsigned NumConcat = DstNumElts / SrcNumElts;
353-
SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
354406
ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
355407
for (unsigned i = 0; i != DstNumElts; ++i) {
356408
int Idx = Mask[i];
409+
const int32_t ShiftIndex = Generator().value_or(-1);
410+
357411
// Undef value.
358-
if (Idx < 0)
412+
if (Idx < 0 || ShiftIndex < 0)
359413
continue;
360-
// Ensure the indices in each SrcType sized piece are sequential and that
414+
415+
// Ensure the indices in each SrcType sized piece are sequential and that
361416
// the same source is used for the whole piece.
362-
if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
363-
(ConcatSrcs[i / SrcNumElts] >= 0 &&
364-
ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
417+
if ((Idx % SrcNumElts != (ShiftIndex % SrcNumElts)))
365418
return false;
366-
// Remember which source this index came from.
367-
ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
368419
}
369420

370-
// The shuffle is concatenating multiple vectors together.
371-
// Collect the different operands for that.
372-
Register UndefReg;
373-
Register Src2 = MI.getOperand(2).getReg();
374-
for (auto Src : ConcatSrcs) {
375-
if (Src < 0) {
376-
if (!UndefReg) {
377-
Builder.setInsertPt(*MI.getParent(), MI);
378-
UndefReg = Builder.buildUndef(SrcType).getReg(0);
379-
}
380-
Ops.push_back(UndefReg);
381-
} else if (Src == 0)
382-
Ops.push_back(Src1);
383-
else
384-
Ops.push_back(Src2);
385-
}
386421
return true;
387422
}
388423

0 commit comments

Comments
 (0)