@@ -410,6 +410,150 @@ void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+// Matches {U/S}ADDV(ext(x)) => {U/S}ADDLV(x)
+// Ensure that the type coming from the extend instruction is the right size.
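+//
+// Illustrative MIR (a sketch; register names are for exposition only):
+//   %ext:_(<8 x s32>) = G_ZEXT %src:_(<8 x s8>)
+//   %sum:_(s32) = G_VECREDUCE_ADD %ext
+// is matched here and rewritten by applyExtUaddvToUaddlv below into a single
+// G_UADDLV of %src, whose lane-0 result is truncated to s16 and then
+// zero-extended back to s32.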
+bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  // Check whether the reduction operand is defined by an extend
+  MachineInstr *ExtMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  auto ExtOpc = ExtMI->getOpcode();
+
+  if (ExtOpc == TargetOpcode::G_ZEXT)
+    std::get<1>(MatchInfo) = false;
+  else if (ExtOpc == TargetOpcode::G_SEXT)
+    std::get<1>(MatchInfo) = true;
+  else
+    return false;
+
+  // Check that the source vector type is one we can lower to {U/S}ADDLV,
+  // possibly after splitting it into smaller vectors.
+  Register ExtSrcReg = ExtMI->getOperand(1).getReg();
+  LLT ExtSrcTy = MRI.getType(ExtSrcReg);
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  if ((DstTy.getScalarSizeInBits() == 16 &&
+       ExtSrcTy.getNumElements() % 8 == 0 && ExtSrcTy.getNumElements() < 256) ||
+      (DstTy.getScalarSizeInBits() == 32 &&
+       ExtSrcTy.getNumElements() % 4 == 0) ||
+      (DstTy.getScalarSizeInBits() == 64 &&
+       ExtSrcTy.getNumElements() % 4 == 0)) {
+    std::get<0>(MatchInfo) = ExtSrcReg;
+    return true;
+  }
+  return false;
+}
+
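+// MatchInfo is (extend source register, isSigned); isSigned selects G_SADDLV
+// over G_UADDLV below.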
+void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           MachineIRBuilder &B, GISelChangeObserver &Observer,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  unsigned Opc = std::get<1>(MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
+  Register SrcReg = std::get<0>(MatchInfo);
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  LLT DstTy = MRI.getType(DstReg);
+
+  // If SrcTy has more elements than a single {U/S}ADDLV can take, split the
+  // source into multiple instructions and sum the results.
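+  // For example, a <32 x s8> source is split into two <16 x s8> parts, each
+  // reduced separately.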
+  LLT MainTy;
+  SmallVector<Register, 1> WorkingRegisters;
+  unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
+  unsigned SrcNumElem = SrcTy.getNumElements();
+  if ((SrcScalSize == 8 && SrcNumElem > 16) ||
+      (SrcScalSize == 16 && SrcNumElem > 8) ||
+      (SrcScalSize == 32 && SrcNumElem > 4)) {
+
+    LLT LeftoverTy;
+    SmallVector<Register, 4> LeftoverRegs;
+    if (SrcScalSize == 8)
+      MainTy = LLT::fixed_vector(16, 8);
+    else if (SrcScalSize == 16)
+      MainTy = LLT::fixed_vector(8, 16);
+    else if (SrcScalSize == 32)
+      MainTy = LLT::fixed_vector(4, 32);
+    else
+      llvm_unreachable("Source's Scalar Size not supported");
+
+    // Extract the parts and collect them in WorkingRegisters; each part is
+    // then run through {U/S}ADDLV below.
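+    // For example, a <24 x s8> source yields one <16 x s8> main part plus a
+    // <8 x s8> leftover.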
+    extractParts(SrcReg, SrcTy, MainTy, LeftoverTy, WorkingRegisters,
+                 LeftoverRegs, B, MRI);
+    for (unsigned I = 0; I < LeftoverRegs.size(); I++) {
+      WorkingRegisters.push_back(LeftoverRegs[I]);
+    }
+  } else {
+    WorkingRegisters.push_back(SrcReg);
+    MainTy = SrcTy;
+  }
+
+  unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
+  LLT MidScalarLLT = LLT::scalar(MidScalarSize);
+  Register zeroReg = B.buildConstant(LLT::scalar(64), 0).getReg(0);
+  for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
+    // If a part has too few elements to build an instruction, widen it
+    // before applying addlv.
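+    // For instance, a leftover <4 x s8> part is widened to <4 x s16> first,
+    // since no ADDLV variant takes four 8-bit lanes.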
+    LLT WorkingRegTy = MRI.getType(WorkingRegisters[I]);
+    if ((WorkingRegTy.getScalarSizeInBits() == 8) &&
+        (WorkingRegTy.getNumElements() == 4)) {
+      WorkingRegisters[I] =
+          B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                              : TargetOpcode::G_ZEXT,
+                       {LLT::fixed_vector(4, 16)}, {WorkingRegisters[I]})
+              .getReg(0);
+    }
+
+    // Generate the {U/S}ADDLV instruction, whose output is always double the
+    // source's scalar size.
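+    // e.g. {U/S}ADDLV of <8 x s16> is modelled as producing a <4 x s32> with
+    // the s32 sum in lane 0.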
+    LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
+                                      : LLT::fixed_vector(2, 64);
+    Register addlvReg =
+        B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]}).getReg(0);
+
+    // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32 or
+    // v2i64 register:
+    //   i16 and i32 results use v4i32 registers;
+    //   i64 results use v2i64 registers.
+    // Therefore we have to extract/truncate the value to the right type.
+    if (MidScalarSize == 32 || MidScalarSize == 64) {
+      WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {MidScalarLLT}, {addlvReg, zeroReg})
+                                .getReg(0);
+    } else {
+      Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {LLT::scalar(32)}, {addlvReg, zeroReg})
+                                .getReg(0);
+      WorkingRegisters[I] =
+          B.buildTrunc({MidScalarLLT}, {extractReg}).getReg(0);
+    }
+  }
+
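+  // Sum the per-part results into a single scalar of the mid type.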
+  Register outReg;
+  if (WorkingRegisters.size() > 1) {
+    outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
+                 .getReg(0);
+    for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
+      outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I]).getReg(0);
+    }
+  } else {
+    outReg = WorkingRegisters[0];
+  }
+
+  if (DstTy.getScalarSizeInBits() > MidScalarSize) {
+    // If DstTy's scalar size is more than double the source's, the {U/S}ADDLV
+    // result must be extended again to reach the destination type.
+    B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                        : TargetOpcode::G_ZEXT,
+                 {DstReg}, {outReg});
+  } else {
+    B.buildCopy(DstReg, outReg);
+  }
+
+  MI.eraseFromParent();
+}
+
 bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
                         CombinerHelper &Helper, GISelChangeObserver &Observer) {
   // Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if