@@ -135,6 +135,7 @@ class VectorCombine {
   bool foldShuffleOfIntrinsics(Instruction &I);
   bool foldShuffleToIdentity(Instruction &I);
   bool foldShuffleFromReductions(Instruction &I);
+  bool foldShuffleChainsToReduce(Instruction &I);
   bool foldCastFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
   bool foldInterleaveIntrinsics(Instruction &I);
@@ -3136,6 +3137,267 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
   return MadeChanges;
 }

+/// For a given chain of patterns of the following form:
+///
+/// ```
+/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison, <n x ty2> mask
+///
+/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
+/// ty1> %1)
+/// OR
+/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
+///
+/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison, <n x ty2> mask
+/// ...
+/// ...
+/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
+/// 3), <n x ty1> %(i - 2))
+/// OR
+/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
+///
+/// %(i) = extractelement <n x ty1> %(i - 1), 0
+/// ```
+///
+/// Where:
+/// `mask` follows a partition pattern:
+///
+/// Ex:
+/// [n = 8, p = poison]
+///
+/// 4 5 6 7 | p p p p
+/// 2 3     | p p p p p p
+/// 1       | p p p p p p p
+///
+/// For powers of 2, there's a consistent pattern, but for other cases
+/// the parity of the current half value at each step decides the
+/// next partition half (see `ExpectedParityMask` for more logical details
+/// in generalising this).
+///
+/// Ex:
+/// [n = 6]
+///
+/// 3 4 5 | p p p
+/// 1 2   | p p p p
+/// 1     | p p p p p
+bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
+  // Going bottom-up for the pattern.
+  std::queue<Value *> InstWorklist;
+  InstructionCost OrigCost = 0;
+
+  // Common instruction operation after each shuffle op.
+  std::optional<unsigned int> CommonCallOp = std::nullopt;
+  std::optional<Instruction::BinaryOps> CommonBinOp = std::nullopt;
+
+  bool IsFirstCallOrBinInst = true;
+  bool ShouldBeCallOrBinInst = true;
+
+  // This stores the last used instructions for shuffle/common op.
+  //
+  // PrevVecV[0] / PrevVecV[1] store the last two simultaneous
+  // instructions from either shuffle/common op.
+  SmallVector<Value *, 2> PrevVecV(2, nullptr);
+
+  Value *VecOpEE;
+  if (!match(&I, m_ExtractElt(m_Value(VecOpEE), m_Zero())))
+    return false;
+
+  auto *FVT = dyn_cast<FixedVectorType>(VecOpEE->getType());
+  if (!FVT)
+    return false;
+
+  int64_t VecSize = FVT->getNumElements();
+  if (VecSize < 2)
+    return false;
+
+  // Number of levels would be ~log2(n), considering we always partition
+  // by half for this fold pattern.
+  unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
+  int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;
+
+  // This is how we generalise for all element sizes.
+  // At each step, if the vector size is odd, we need non-poison
+  // values to cover the dominant half so we don't miss out on any element.
+  //
+  // This mask will help us retrieve this as we go from bottom to top:
+  //
+  // Mask Set -> N = N * 2 - 1
+  // Mask Unset -> N = N * 2
+  for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
+       Cur = (Cur + 1) / 2, --Mask) {
+    if (Cur & 1)
+      ExpectedParityMask |= (1ll << Mask);
+  }
+
+  InstWorklist.push(VecOpEE);
+
+  while (!InstWorklist.empty()) {
+    Value *CI = InstWorklist.front();
+    InstWorklist.pop();
+
+    if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
+      if (!ShouldBeCallOrBinInst)
+        return false;
+
+      if (!IsFirstCallOrBinInst &&
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      // For the first found call/bin op, the vector has to come from the
+      // extract element op.
+      if (II != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
+        return false;
+      IsFirstCallOrBinInst = false;
+
+      if (!CommonCallOp)
+        CommonCallOp = II->getIntrinsicID();
+      if (II->getIntrinsicID() != *CommonCallOp)
+        return false;
+
+      switch (II->getIntrinsicID()) {
+      case Intrinsic::umin:
+      case Intrinsic::umax:
+      case Intrinsic::smin:
+      case Intrinsic::smax: {
+        auto *Op0 = II->getOperand(0);
+        auto *Op1 = II->getOperand(1);
+        PrevVecV[0] = Op0;
+        PrevVecV[1] = Op1;
+        break;
+      }
+      default:
+        return false;
+      }
+      ShouldBeCallOrBinInst ^= 1;
+
+      IntrinsicCostAttributes ICA(
+          *CommonCallOp, II->getType(),
+          {PrevVecV[0]->getType(), PrevVecV[1]->getType()});
+      OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
+
+      // We may need a swap here since it can be (a, b) or (b, a)
+      // and accordingly change as we go up.
+      if (!isa<ShuffleVectorInst>(PrevVecV[1]))
+        std::swap(PrevVecV[0], PrevVecV[1]);
+      InstWorklist.push(PrevVecV[1]);
+      InstWorklist.push(PrevVecV[0]);
+    } else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
+      // Similar logic for bin ops.
+
+      if (!ShouldBeCallOrBinInst)
+        return false;
+
+      if (!IsFirstCallOrBinInst &&
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      if (BinOp != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
+        return false;
+      IsFirstCallOrBinInst = false;
+
+      if (!CommonBinOp)
+        CommonBinOp = BinOp->getOpcode();
+
+      if (BinOp->getOpcode() != *CommonBinOp)
+        return false;
+
+      switch (*CommonBinOp) {
+      case BinaryOperator::Add:
+      case BinaryOperator::Mul:
+      case BinaryOperator::Or:
+      case BinaryOperator::And:
+      case BinaryOperator::Xor: {
+        auto *Op0 = BinOp->getOperand(0);
+        auto *Op1 = BinOp->getOperand(1);
+        PrevVecV[0] = Op0;
+        PrevVecV[1] = Op1;
+        break;
+      }
+      default:
+        return false;
+      }
+      ShouldBeCallOrBinInst ^= 1;
+
+      OrigCost +=
+          TTI.getArithmeticInstrCost(*CommonBinOp, BinOp->getType(), CostKind);
+
+      if (!isa<ShuffleVectorInst>(PrevVecV[1]))
+        std::swap(PrevVecV[0], PrevVecV[1]);
+      InstWorklist.push(PrevVecV[1]);
+      InstWorklist.push(PrevVecV[0]);
+    } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
+      // We shouldn't have any null values in the previous vectors;
+      // if so, there was a mismatch in the pattern.
+      if (ShouldBeCallOrBinInst ||
+          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
+        return false;
+
+      if (SVInst != PrevVecV[1])
+        return false;
+
+      ArrayRef<int> CurMask;
+      if (!match(SVInst, m_Shuffle(m_Specific(PrevVecV[0]), m_Poison(),
+                                   m_Mask(CurMask))))
+        return false;
+
+      // Subtract the parity mask when checking the condition.
+      for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
+        if (Mask < ShuffleMaskHalf &&
+            CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
+          return false;
+        if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
+          return false;
+      }
+
+      // Update mask values.
+      ShuffleMaskHalf *= 2;
+      ShuffleMaskHalf -= (ExpectedParityMask & 1);
+      ExpectedParityMask >>= 1;
+
+      OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
+                                     SVInst->getType(), SVInst->getType(),
+                                     CurMask, CostKind);
+
+      VisitedCnt += 1;
+      if (!ExpectedParityMask && VisitedCnt == NumLevels)
+        break;
+
+      ShouldBeCallOrBinInst ^= 1;
+    } else {
+      return false;
+    }
+  }
+
+  // Pattern should end with a shuffle op.
+  if (ShouldBeCallOrBinInst)
+    return false;
+
+  assert(VecSize != -1 && "Expected Match for Vector Size");
+
+  Value *FinalVecV = PrevVecV[0];
+  if (!FinalVecV)
+    return false;
+
+  auto *FinalVecVTy = cast<FixedVectorType>(FinalVecV->getType());
+
+  Intrinsic::ID ReducedOp =
+      (CommonCallOp ? getMinMaxReductionIntrinsicID(*CommonCallOp)
+                    : getReductionForBinop(*CommonBinOp));
+  if (!ReducedOp)
+    return false;
+
+  IntrinsicCostAttributes ICA(ReducedOp, FinalVecVTy, {FinalVecV});
+  InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
+
+  if (NewCost >= OrigCost)
+    return false;
+
+  auto *ReducedResult =
+      Builder.CreateIntrinsic(ReducedOp, {FinalVecV->getType()}, {FinalVecV});
+  replaceValue(I, *ReducedResult);
+
+  return true;
+}
+
 /// Determine if its more efficient to fold:
 /// reduce(trunc(x)) -> trunc(reduce(x)).
 /// reduce(sext(x)) -> sext(reduce(x)).
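To make the shape of the fold concrete, here is a hand-written sketch of the kind of chain it targets (not taken from the patch; the function and value names are made up), using an `<8 x i32>` input and `umin`, with masks following the documented partition pattern. If the target's cost model rates the reduction intrinsic cheaper than the accumulated shuffle/min costs, the whole chain collapses to one call:

```llvm
define i32 @umin_chain_v8(<8 x i32> %v) {
  ; mask 4 5 6 7 | p p p p
  %s1 = shufflevector <8 x i32> %v, <8 x i32> poison, <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison>
  %m1 = tail call <8 x i32> @llvm.umin.v8i32(<8 x i32> %v, <8 x i32> %s1)
  ; mask 2 3 | p p p p p p
  %s2 = shufflevector <8 x i32> %m1, <8 x i32> poison, <8 x i32> <i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
  %m2 = tail call <8 x i32> @llvm.umin.v8i32(<8 x i32> %m1, <8 x i32> %s2)
  ; mask 1 | p p p p p p p
  %s3 = shufflevector <8 x i32> %m2, <8 x i32> poison, <8 x i32> <i32 1, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
  %m3 = tail call <8 x i32> @llvm.umin.v8i32(<8 x i32> %m2, <8 x i32> %s3)
  %r = extractelement <8 x i32> %m3, i64 0
  ret i32 %r
}

; When NewCost < OrigCost, the chain is replaced with a single reduction:
;   %r = call i32 @llvm.vector.reduce.umin.v8i32(<8 x i32> %v)
```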
@@ -4223,6 +4485,9 @@ bool VectorCombine::run() {
       if (foldCastFromReductions(I))
         return true;
       break;
+    case Instruction::ExtractElement:
+      if (foldShuffleChainsToReduce(I))
+        return true;
     case Instruction::ICmp:
     case Instruction::FCmp:
       if (foldExtractExtract(I))
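The non-power-of-two case is what `ExpectedParityMask` exists for. Below is another hypothetical sketch (again not from the commit) of a `<6 x i32>` chain using `smax`, whose masks follow the `3 4 5` / `1 2` / `1` partition from the doc comment; `smax` is chosen here because min/max are insensitive to a lane being folded in twice, which this partition does for odd sizes:

```llvm
define i32 @smax_chain_v6(<6 x i32> %v) {
  ; mask 3 4 5 | p p p
  %s1 = shufflevector <6 x i32> %v, <6 x i32> poison, <6 x i32> <i32 3, i32 4, i32 5, i32 poison, i32 poison, i32 poison>
  %m1 = tail call <6 x i32> @llvm.smax.v6i32(<6 x i32> %v, <6 x i32> %s1)
  ; mask 1 2 | p p p p
  %s2 = shufflevector <6 x i32> %m1, <6 x i32> poison, <6 x i32> <i32 1, i32 2, i32 poison, i32 poison, i32 poison, i32 poison>
  %m2 = tail call <6 x i32> @llvm.smax.v6i32(<6 x i32> %m1, <6 x i32> %s2)
  ; mask 1 | p p p p p
  %s3 = shufflevector <6 x i32> %m2, <6 x i32> poison, <6 x i32> <i32 1, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
  %m3 = tail call <6 x i32> @llvm.smax.v6i32(<6 x i32> %m2, <6 x i32> %s3)
  %r = extractelement <6 x i32> %m3, i64 0
  ret i32 %r
}

; Candidate rewrite, subject to the same cost comparison:
;   %r = call i32 @llvm.vector.reduce.smax.v6i32(<6 x i32> %v)
```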
0 commit comments