Skip to content
90 changes: 90 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class VectorCombine {
bool foldShuffleOfSelects(Instruction &I);
bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleOfShuffles(Instruction &I);
bool foldPermuteOfIntrinsic(Instruction &I);
bool foldShuffleOfIntrinsics(Instruction &I);
bool foldShuffleToIdentity(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
Expand Down Expand Up @@ -2960,6 +2961,93 @@ bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
return true;
}

/// Try to convert
/// "shuffle (intrinsic), (poison/undef)" into "intrinsic (shuffle)".
bool VectorCombine::foldPermuteOfIntrinsic(Instruction &I) {
Value *V0, *V1;
ArrayRef<int> Mask;
if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_Value(V1), m_Mask(Mask))))
return false;

// Check for permute
if (!match(V1, m_Poison()) && !match(V1, m_Undef())) {
LLVM_DEBUG(dbgs() << "not a permute\n");
return false;
}

auto *II0 = dyn_cast<IntrinsicInst>(V0);
if (!II0)
return false;

auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
auto *IntrinsicSrcTy = dyn_cast<FixedVectorType>(II0->getType());
if (!ShuffleDstTy || !IntrinsicSrcTy)
return false;

// Validate it's a pure permute, mask should only reference the first vector
unsigned NumSrcElts = IntrinsicSrcTy->getNumElements();
for (int Idx : Mask) {
if (Idx > 0 && Idx >= (int)NumSrcElts)
return false;
}

Intrinsic::ID IID = II0->getIntrinsicID();
if (!isTriviallyVectorizable(IID))
return false;

// Cost analysis
InstructionCost OldCost =
TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind) +
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleDstTy,
IntrinsicSrcTy, Mask, CostKind);

SmallVector<Type *> NewArgsTy;
InstructionCost NewCost = 0;

for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
NewArgsTy.push_back(II0->getArgOperand(I)->getType());
} else {
auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
auto *ArgTy = FixedVectorType::get(VecTy->getElementType(),
ShuffleDstTy->getNumElements());
NewArgsTy.push_back(ArgTy);
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
ArgTy, VecTy, Mask, CostKind);
}
}
IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);

LLVM_DEBUG(dbgs() << "Found a permute of intrinsic: " << I << "\n OldCost: "
<< OldCost << " vs NewCost: " << NewCost << "\n");

if (NewCost > OldCost)
return false;

// Transform
SmallVector<Value *> NewArgs;
for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
NewArgs.push_back(II0->getArgOperand(I));
} else {
Value *Shuf = Builder.CreateShuffleVector(
II0->getArgOperand(I),
PoisonValue::get(II0->getArgOperand(I)->getType()), Mask);
NewArgs.push_back(Shuf);
Worklist.pushValue(Shuf);
}
}

Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);

if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic))
NewInst->copyIRFlags(II0);

replaceValue(I, *NewIntrinsic);
return true;
}

using InstLane = std::pair<Use *, int>;

static InstLane lookThroughShuffles(Use *U, int Lane) {
Expand Down Expand Up @@ -4718,6 +4806,8 @@ bool VectorCombine::run() {
return true;
if (foldShuffleOfShuffles(I))
return true;
if (foldPermuteOfIntrinsic(I))
return true;
if (foldShuffleOfIntrinsics(I))
return true;
if (foldSelectShuffle(I))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -passes=vector-combine -S -mtriple=aarch64 %s | FileCheck %s

; This file tests the foldPermuteOfIntrinsic optimization which transforms:
; shuffle(intrinsic(args), poison/undef) -> intrinsic(shuffle(args))
; when the shuffle is a permute (operates on single vector) and cost model
; determines the transformation is beneficial.

;; ============================================================================
;; Positive Tests - Should Optimize
;; ============================================================================

define <4 x i32> @extract_lower_sadd_sat(<8 x i32> %v1, <8 x i32> %v2) {
; CHECK-LABEL: @extract_lower_sadd_sat(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[V1:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[V2:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]])
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%sat = call <8 x i32> @llvm.sadd.sat.v8i32(<8 x i32> %v1, <8 x i32> %v2)
%result = shufflevector <8 x i32> %sat, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x i32> %result
}

define <4 x i32> @extract_lower_uadd_sat(<8 x i32> %v1, <8 x i32> %v2) {
; CHECK-LABEL: @extract_lower_uadd_sat(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[V1:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[V2:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]])
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%sat = call <8 x i32> @llvm.uadd.sat.v8i32(<8 x i32> %v1, <8 x i32> %v2)
%result = shufflevector <8 x i32> %sat, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x i32> %result
}

define <4 x float> @extract_lower_fma(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
; CHECK-LABEL: @extract_lower_fma(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x float> [[A:%.*]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[B:%.*]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <8 x float> [[C:%.*]], <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> [[TMP1]], <4 x float> [[TMP2]], <4 x float> [[TMP3]])
; CHECK-NEXT: ret <4 x float> [[RESULT]]
;
%fma = call <8 x float> @llvm.fma.v8f32(<8 x float> %a, <8 x float> %b, <8 x float> %c)
%result = shufflevector <8 x float> %fma, <8 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x float> %result
}

define <4 x i32> @extract_lower_abs_should_not_shuffle_scalar(<8 x i32> %v) {
; CHECK-LABEL: @extract_lower_abs_should_not_shuffle_scalar(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[V:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP1]], i1 false)
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%abs = call <8 x i32> @llvm.abs.v8i32(<8 x i32> %v, i1 false)
%result = shufflevector <8 x i32> %abs, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x i32> %result
}

define <2 x i64> @extract_lower_i64(<4 x i64> %v1, <4 x i64> %v2) {
; CHECK-LABEL: @extract_lower_i64(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i64> [[V1:%.*]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i64> [[V2:%.*]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
; CHECK-NEXT: [[RESULT:%.*]] = call <2 x i64> @llvm.sadd.sat.v2i64(<2 x i64> [[TMP1]], <2 x i64> [[TMP2]])
; CHECK-NEXT: ret <2 x i64> [[RESULT]]
;
%sat = call <4 x i64> @llvm.sadd.sat.v4i64(<4 x i64> %v1, <4 x i64> %v2)
%result = shufflevector <4 x i64> %sat, <4 x i64> poison, <2 x i32> <i32 0, i32 1>
ret <2 x i64> %result
}

define <8 x i16> @extract_lower_i16(<16 x i16> %v1, <16 x i16> %v2) {
; CHECK-LABEL: @extract_lower_i16(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <16 x i16> [[V1:%.*]], <16 x i16> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <16 x i16> [[V2:%.*]], <16 x i16> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[RESULT:%.*]] = call <8 x i16> @llvm.sadd.sat.v8i16(<8 x i16> [[TMP1]], <8 x i16> [[TMP2]])
; CHECK-NEXT: ret <8 x i16> [[RESULT]]
;
%sat = call <16 x i16> @llvm.sadd.sat.v16i16(<16 x i16> %v1, <16 x i16> %v2)
%result = shufflevector <16 x i16> %sat, <16 x i16> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
ret <8 x i16> %result
}

define <4 x i32> @extract_lower_with_undef(<8 x i32> %v1, <8 x i32> %v2) {
; CHECK-LABEL: @extract_lower_with_undef(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[V1:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[V2:%.*]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[RESULT:%.*]] = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> [[TMP1]], <4 x i32> [[TMP2]])
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%sat = call <8 x i32> @llvm.sadd.sat.v8i32(<8 x i32> %v1, <8 x i32> %v2)
%result = shufflevector <8 x i32> %sat, <8 x i32> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x i32> %result
}

;; ============================================================================
;; Negative Tests - Should NOT Optimize
;; ============================================================================

define <4 x i32> @same_size_permute(<4 x i32> %v1, <4 x i32> %v2) {
; CHECK-LABEL: @same_size_permute(
; CHECK-NEXT: [[SAT:%.*]] = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> [[V1:%.*]], <4 x i32> [[V2:%.*]])
; CHECK-NEXT: [[RESULT:%.*]] = shufflevector <4 x i32> [[SAT]], <4 x i32> poison, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%sat = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> %v1, <4 x i32> %v2)
%result = shufflevector <4 x i32> %sat, <4 x i32> poison, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
ret <4 x i32> %result
}

define <4 x i32> @not_a_permute_uses_second_operand(<4 x i32> %v1, <4 x i32> %v2, <4 x i32> %other) {
; CHECK-LABEL: @not_a_permute_uses_second_operand(
; CHECK-NEXT: [[SAT:%.*]] = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> [[V1:%.*]], <4 x i32> [[V2:%.*]])
; CHECK-NEXT: [[RESULT:%.*]] = shufflevector <4 x i32> [[SAT]], <4 x i32> [[OTHER:%.*]], <4 x i32> <i32 0, i32 4, i32 1, i32 5>
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%sat = call <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32> %v1, <4 x i32> %v2)
%result = shufflevector <4 x i32> %sat, <4 x i32> %other, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
ret <4 x i32> %result
}

define <4 x i32> @not_an_intrinsic(<8 x i32> %v1, <8 x i32> %v2) {
; CHECK-LABEL: @not_an_intrinsic(
; CHECK-NEXT: [[ADD:%.*]] = add <8 x i32> [[V1:%.*]], [[V2:%.*]]
; CHECK-NEXT: [[RESULT:%.*]] = shufflevector <8 x i32> [[ADD]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: ret <4 x i32> [[RESULT]]
;
%add = add <8 x i32> %v1, %v2
%result = shufflevector <8 x i32> %add, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
ret <4 x i32> %result
}

declare <8 x i32> @llvm.sadd.sat.v8i32(<8 x i32>, <8 x i32>)
declare <4 x i32> @llvm.sadd.sat.v4i32(<4 x i32>, <4 x i32>)
declare <4 x i64> @llvm.sadd.sat.v4i64(<4 x i64>, <4 x i64>)
declare <2 x i64> @llvm.sadd.sat.v2i64(<2 x i64>, <2 x i64>)
declare <16 x i16> @llvm.sadd.sat.v16i16(<16 x i16>, <16 x i16>)
declare <8 x i16> @llvm.sadd.sat.v8i16(<8 x i16>, <8 x i16>)

declare <8 x i32> @llvm.uadd.sat.v8i32(<8 x i32>, <8 x i32>)
declare <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32>, <4 x i32>)

declare <8 x i32> @llvm.abs.v8i32(<8 x i32>, i1 immarg)
declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1 immarg)

declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>)
declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ define <8 x i8> @abs_different(<8 x i8> %a) {

define <4 x i32> @poison_intrinsic(<2 x i16> %l256) {
; CHECK-LABEL: @poison_intrinsic(
; CHECK-NEXT: [[L266:%.*]] = call <2 x i16> @llvm.abs.v2i16(<2 x i16> [[L256:%.*]], i1 false)
; CHECK-NEXT: [[L267:%.*]] = shufflevector <2 x i16> [[L266]], <2 x i16> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[L271:%.*]] = zext <4 x i16> [[L267]] to <4 x i32>
; CHECK-NEXT: [[L267:%.*]] = shufflevector <2 x i16> [[L266:%.*]], <2 x i16> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[L267]], i1 false)
; CHECK-NEXT: [[L271:%.*]] = zext <4 x i16> [[TMP2]] to <4 x i32>
; CHECK-NEXT: ret <4 x i32> [[L271]]
;
%l266 = call <2 x i16> @llvm.abs.v2i16(<2 x i16> %l256, i1 false)
Expand Down
Loading