@@ -376,6 +376,104 @@ static VectorType *getVRGatherIndexType(MVT DataVT, const RISCVSubtarget &ST,
376376 return cast<VectorType>(EVT (IndexVT).getTypeForEVT (C));
377377}
378378
379+ // / Try to perform better estimation of the permutation.
380+ // / 1. Split the source/destination vectors into real registers.
381+ // / 2. Do the mask analysis to identify which real registers are
382+ // / permuted. If more than 1 source registers are used for the
383+ // / destination register building, the cost for this destination register
384+ // / is (Number_of_source_register - 1) * Cost_PermuteTwoSrc. If only one
385+ // / source register is used, build mask and calculate the cost as a cost
386+ // / of PermuteSingleSrc.
387+ // / Also, for the single register permute we try to identify if the
388+ // / destination register is just a copy of the source register or the
389+ // / copy of the previous destination register (the cost is
390+ // / TTI::TCC_Basic). If the source register is just reused, the cost for
391+ // / this operation is 0.
392+ static InstructionCost
393+ costShuffleViaVRegSplitting (RISCVTTIImpl &TTI, MVT LegalVT,
394+ std::optional<unsigned > VLen, VectorType *Tp,
395+ ArrayRef<int > Mask, TTI::TargetCostKind CostKind) {
396+ InstructionCost NumOfDests = InstructionCost::getInvalid ();
397+ if (VLen && LegalVT.isFixedLengthVector () && !Mask.empty ()) {
398+ MVT ElemVT = LegalVT.getVectorElementType ();
399+ unsigned ElemsPerVReg = *VLen / ElemVT.getFixedSizeInBits ();
400+ LegalVT = TTI.getTypeLegalizationCost (
401+ FixedVectorType::get (Tp->getElementType (), ElemsPerVReg))
402+ .second ;
403+ // Number of destination vectors after legalization:
404+ NumOfDests = divideCeil (Mask.size (), LegalVT.getVectorNumElements ());
405+ }
406+ if (!NumOfDests.isValid () || NumOfDests <= 1 ||
407+ !LegalVT.isFixedLengthVector () ||
408+ LegalVT.getVectorElementType ().getSizeInBits () !=
409+ Tp->getElementType ()->getPrimitiveSizeInBits () ||
410+ LegalVT.getVectorNumElements () >= Tp->getElementCount ().getFixedValue ())
411+ return InstructionCost::getInvalid ();
412+
413+ unsigned VecTySize = TTI.getDataLayout ().getTypeStoreSize (Tp);
414+ unsigned LegalVTSize = LegalVT.getStoreSize ();
415+ // Number of source vectors after legalization:
416+ unsigned NumOfSrcs = divideCeil (VecTySize, LegalVTSize);
417+
418+ auto *SingleOpTy = FixedVectorType::get (Tp->getElementType (),
419+ LegalVT.getVectorNumElements ());
420+
421+ unsigned E = *NumOfDests.getValue ();
422+ unsigned NormalizedVF =
423+ LegalVT.getVectorNumElements () * std::max (NumOfSrcs, E);
424+ unsigned NumOfSrcRegs = NormalizedVF / LegalVT.getVectorNumElements ();
425+ unsigned NumOfDestRegs = NormalizedVF / LegalVT.getVectorNumElements ();
426+ SmallVector<int > NormalizedMask (NormalizedVF, PoisonMaskElem);
427+ assert (NormalizedVF >= Mask.size () &&
428+ " Normalized mask expected to be not shorter than original mask." );
429+ copy (Mask, NormalizedMask.begin ());
430+ InstructionCost Cost = 0 ;
431+ SmallBitVector ExtractedRegs (2 * NumOfSrcRegs);
432+ int NumShuffles = 0 ;
433+ processShuffleMasks (
434+ NormalizedMask, NumOfSrcRegs, NumOfDestRegs, NumOfDestRegs, []() {},
435+ [&](ArrayRef<int > RegMask, unsigned SrcReg, unsigned DestReg) {
436+ if (ExtractedRegs.test (SrcReg)) {
437+ Cost += TTI.getShuffleCost (TTI::SK_ExtractSubvector, Tp, {}, CostKind,
438+ (SrcReg % NumOfSrcRegs) *
439+ SingleOpTy->getNumElements (),
440+ SingleOpTy);
441+ ExtractedRegs.set (SrcReg);
442+ }
443+ if (!ShuffleVectorInst::isIdentityMask (RegMask, RegMask.size ())) {
444+ ++NumShuffles;
445+ Cost += TTI.getShuffleCost (TTI::SK_PermuteSingleSrc, SingleOpTy,
446+ RegMask, CostKind, 0 , nullptr );
447+ return ;
448+ }
449+ },
450+ [&](ArrayRef<int > RegMask, unsigned Idx1, unsigned Idx2, bool NewReg) {
451+ if (ExtractedRegs.test (Idx1)) {
452+ Cost += TTI.getShuffleCost (
453+ TTI::SK_ExtractSubvector, Tp, {}, CostKind,
454+ (Idx1 % NumOfSrcRegs) * SingleOpTy->getNumElements (), SingleOpTy);
455+ ExtractedRegs.set (Idx1);
456+ }
457+ if (ExtractedRegs.test (Idx2)) {
458+ Cost += TTI.getShuffleCost (
459+ TTI::SK_ExtractSubvector, Tp, {}, CostKind,
460+ (Idx2 % NumOfSrcRegs) * SingleOpTy->getNumElements (), SingleOpTy);
461+ ExtractedRegs.set (Idx2);
462+ }
463+ Cost += TTI.getShuffleCost (TTI::SK_PermuteTwoSrc, SingleOpTy, RegMask,
464+ CostKind, 0 , nullptr );
465+ NumShuffles += 2 ;
466+ });
467+ // Note: check that we do not emit too many shuffles here to prevent code
468+ // size explosion.
469+ // TODO: investigate, if it can be improved by extra analysis of the masks
470+ // to check if the code is more profitable.
471+ if ((NumOfDestRegs > 2 && NumShuffles <= static_cast <int >(NumOfDestRegs)) ||
472+ (NumOfDestRegs <= 2 && NumShuffles < 4 ))
473+ return Cost;
474+ return InstructionCost::getInvalid ();
475+ }
476+
379477InstructionCost RISCVTTIImpl::getShuffleCost (TTI::ShuffleKind Kind,
380478 VectorType *Tp, ArrayRef<int > Mask,
381479 TTI::TargetCostKind CostKind,
@@ -389,7 +487,11 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
389487 // First, handle cases where having a fixed length vector enables us to
390488 // give a more accurate cost than falling back to generic scalable codegen.
391489 // TODO: Each of these cases hints at a modeling gap around scalable vectors.
392- if (isa<FixedVectorType>(Tp)) {
490+ if (ST->hasVInstructions () && isa<FixedVectorType>(Tp)) {
491+ InstructionCost VRegSplittingCost = costShuffleViaVRegSplitting (
492+ *this , LT.second , ST->getRealVLen (), Tp, Mask, CostKind);
493+ if (VRegSplittingCost.isValid ())
494+ return VRegSplittingCost;
393495 switch (Kind) {
394496 default :
395497 break ;
0 commit comments