3939#include " llvm/Support/MathExtras.h"
4040#include " llvm/Target/TargetMachine.h"
4141#include < cmath>
42+ #include < cstdint>
43+ #include < functional>
4244#include < optional>
4345#include < tuple>
4446
@@ -301,17 +303,77 @@ void CombinerHelper::applyCombineConcatVectors(
301303 replaceRegWith (MRI, DstReg, NewDstReg);
302304}
303305
306+ // Create a stream from 0 to n with a specified number of steps
307+ CombinerHelper::GeneratorType
308+ adderGenerator (const int32_t From, const int32_t To, const int32_t StepSize) {
309+ int32_t Counter = From;
310+ return [Counter, To, StepSize]() mutable {
311+ std::optional<int32_t > OldCount = std::optional<int32_t >(Counter);
312+ Counter += StepSize;
313+ if (OldCount == (To + StepSize))
314+ OldCount = {};
315+ return OldCount;
316+ };
317+ }
318+
304319bool CombinerHelper::tryCombineShuffleVector (MachineInstr &MI) {
320+ const Register DstReg = MI.getOperand (0 ).getReg ();
321+ const LLT DstTy = MRI.getType (DstReg);
322+ const LLT SrcTy = MRI.getType (MI.getOperand (1 ).getReg ());
323+ const unsigned DstNumElts = DstTy.isVector () ? DstTy.getNumElements () : 1 ;
324+ const unsigned SrcNumElts = SrcTy.isVector () ? SrcTy.getNumElements () : 1 ;
325+
326+ // {1, 2, ..., n} -> G_CONCAT_VECTOR
327+ // Turns a shuffle vector that only increments into a concat vector
328+ // instruction
329+ GeneratorType CountUp = adderGenerator (0 , DstNumElts - 1 , 1 );
305330 SmallVector<Register, 4 > Ops;
306- if (matchCombineShuffleVector (MI, Ops)) {
331+ if (matchCombineShuffleVector (MI, CountUp, 2 * SrcNumElts)) {
332+ // The shuffle is concatenating multiple vectors together.
333+ // Collect the different operands for that.
334+ Register UndefReg;
335+ const Register Src1 = MI.getOperand (1 ).getReg ();
336+ const Register Src2 = MI.getOperand (2 ).getReg ();
337+
338+ const ArrayRef<int > Mask = MI.getOperand (3 ).getShuffleMask ();
339+
340+ // The destination can be longer than the source, so we separate them into
341+ // equal blocks and check them separately to see if one of the blocks can be
342+ // copied whole.
343+ unsigned NumConcat = DstNumElts / SrcNumElts;
344+ unsigned Index = 0 ;
345+ for (unsigned Concat = 0 ; Concat < NumConcat; Concat++) {
346+ unsigned Target = (Concat + 1 ) * SrcNumElts;
347+ while (Index < Target) {
348+ int MaskElt = Mask[Index];
349+ if (MaskElt >= 0 ) {
350+ Ops.push_back ((MaskElt < (int )SrcNumElts) ? Src1 : Src2);
351+ break ;
352+ }
353+ Index++;
354+ }
355+
356+ if (Index == Target) {
357+ if (!UndefReg) {
358+ Builder.setInsertPt (*MI.getParent (), MI);
359+ UndefReg = Builder.buildUndef (SrcTy).getReg (0 );
360+ }
361+ Ops.push_back (UndefReg);
362+ }
363+
364+ Index = Target;
365+ }
366+
307367 applyCombineShuffleVector (MI, Ops);
308368 return true ;
309369 }
370+
310371 return false ;
311372}
312373
313374bool CombinerHelper::matchCombineShuffleVector (MachineInstr &MI,
314- SmallVectorImpl<Register> &Ops) {
375+ GeneratorType Generator,
376+ const size_t TargetDstSize) {
315377 assert (MI.getOpcode () == TargetOpcode::G_SHUFFLE_VECTOR &&
316378 " Invalid instruction kind" );
317379 LLT DstType = MRI.getType (MI.getOperand (0 ).getReg ());
@@ -338,51 +400,24 @@ bool CombinerHelper::matchCombineShuffleVector(MachineInstr &MI,
338400 //
339401 // TODO: If the size between the source and destination don't match
340402 // we could still emit an extract vector element in that case.
341- if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1 )
342- return false ;
343-
344- // Check that the shuffle mask can be broken evenly between the
345- // different sources.
346- if (DstNumElts % SrcNumElts != 0 )
403+ if ((DstNumElts < TargetDstSize) && DstNumElts != 1 )
347404 return false ;
348405
349- // Mask length is a multiple of the source vector length.
350- // Check if the shuffle is some kind of concatenation of the input
351- // vectors.
352- unsigned NumConcat = DstNumElts / SrcNumElts;
353- SmallVector<int , 8 > ConcatSrcs (NumConcat, -1 );
354406 ArrayRef<int > Mask = MI.getOperand (3 ).getShuffleMask ();
355407 for (unsigned i = 0 ; i != DstNumElts; ++i) {
356408 int Idx = Mask[i];
409+ const int32_t ShiftIndex = Generator ().value_or (-1 );
410+
357411 // Undef value.
358- if (Idx < 0 )
412+ if (Idx < 0 || ShiftIndex < 0 )
359413 continue ;
360- // Ensure the indices in each SrcType sized piece are sequential and that
414+
415+ // Ensure the indices in each SrcType sized piece are sequential and that
361416 // the same source is used for the whole piece.
362- if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
363- (ConcatSrcs[i / SrcNumElts] >= 0 &&
364- ConcatSrcs[i / SrcNumElts] != (int )(Idx / SrcNumElts)))
417+ if ((Idx % SrcNumElts != (ShiftIndex % SrcNumElts)))
365418 return false ;
366- // Remember which source this index came from.
367- ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
368419 }
369420
370- // The shuffle is concatenating multiple vectors together.
371- // Collect the different operands for that.
372- Register UndefReg;
373- Register Src2 = MI.getOperand (2 ).getReg ();
374- for (auto Src : ConcatSrcs) {
375- if (Src < 0 ) {
376- if (!UndefReg) {
377- Builder.setInsertPt (*MI.getParent (), MI);
378- UndefReg = Builder.buildUndef (SrcType).getReg (0 );
379- }
380- Ops.push_back (UndefReg);
381- } else if (Src == 0 )
382- Ops.push_back (Src1);
383- else
384- Ops.push_back (Src2);
385- }
386421 return true ;
387422}
388423
0 commit comments