Skip to content

Commit fa77c2c

Browse files
committed
[InstCombine] Transform vector.reduce.add (splat %0, 4) into shl i32 %0, 2
Fixes #160066. When a vector whose elements are all the same value (a splat created with `insertelement` and `shufflevector`) is summed with `vector.reduce.add`, the result is that value multiplied by the number of elements in the result type. If that element count is a power of two, the multiplication can be replaced with a left shift.
1 parent 44e71c9 commit fa77c2c

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3761,6 +3761,39 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
37613761
return replaceInstUsesWith(CI, Res);
37623762
}
37633763
}
3764+
3765+
// Handle the case where a value is multiplied by a power of two.
3766+
// For example:
3767+
// %2 = insertelement <4 x i32> poison, i32 %0, i64 0
3768+
// %3 = shufflevector <4 x i32> %2, poison, <4 x i32> zeroinitializer
3769+
// %4 = tail call i32 @llvm.vector.reduce.add.v4i32(%3)
3770+
// =>
3771+
// %2 = shl i32 %0, 2
3772+
Value *InputValue;
3773+
ArrayRef<int> Mask;
3774+
ConstantInt *InsertionIdx;
3775+
assert(Arg->getType()->isVectorTy() &&
3776+
"The vector.reduce.add intrinsic's argument must be a vector!");
3777+
3778+
if (match(Arg, m_Shuffle(m_InsertElt(m_Poison(), m_Value(InputValue),
3779+
m_ConstantInt(InsertionIdx)),
3780+
m_Poison(), m_Mask(Mask)))) {
3781+
// It is only a multiplication if we add the same element over and over.
3782+
bool AllElementsAreTheSameInMask =
3783+
std::all_of(Mask.begin(), Mask.end(),
3784+
[&Mask](int MaskElt) { return MaskElt == Mask[0]; });
3785+
unsigned ReducedVectorLength = Mask.size();
3786+
3787+
if (AllElementsAreTheSameInMask &&
3788+
InsertionIdx->getSExtValue() == Mask[0] &&
3789+
isPowerOf2_32(ReducedVectorLength)) {
3790+
unsigned Pow2 = Log2_32(ReducedVectorLength);
3791+
Value *Res = Builder.CreateShl(
3792+
InputValue, Constant::getIntegerValue(InputValue->getType(),
3793+
APInt(32, Pow2)));
3794+
return replaceInstUsesWith(CI, Res);
3795+
}
3796+
}
37643797
}
37653798
[[fallthrough]];
37663799
}

llvm/test/Transforms/InstCombine/vector-reductions.ll

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,73 @@ define i32 @diff_of_sums_type_mismatch2(<8 x i32> %v0, <4 x i32> %v1) {
308308
%r = sub i32 %r0, %r1
309309
ret i32 %r
310310
}
311+
312+
; Positive test: %0 is inserted at lane 0 and splatted from lane 0 across
; 4 lanes, so the add reduction equals 4 * %0. The expected output is a
; single `shl` by 2.
define i32 @constant_multiplied_at_0(i32 %0) {
313+
; CHECK-LABEL: @constant_multiplied_at_0(
314+
; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 2
315+
; CHECK-NEXT: ret i32 [[TMP2]]
316+
;
317+
%2 = insertelement <4 x i32> poison, i32 %0, i64 0
318+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison, <4 x i32> zeroinitializer
319+
%4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
320+
ret i32 %4
321+
}
322+
323+
; Positive test: the shuffle widens the 4-lane source into an 8-lane splat
; of lane 0, so the add reduction equals 8 * %0. The expected output is a
; single `shl` by 3 (log2 of the 8 result lanes).
define i32 @constant_multiplied_at_0_two_pow8(i32 %0) {
324+
; CHECK-LABEL: @constant_multiplied_at_0_two_pow8(
325+
; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 3
326+
; CHECK-NEXT: ret i32 [[TMP2]]
327+
;
328+
%2 = insertelement <4 x i32> poison, i32 %0, i64 0
329+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison, <8 x i32> zeroinitializer
330+
%4 = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %3)
331+
ret i32 %4
332+
}
333+
334+
335+
; Positive test: the shuffle widens the 4-lane source into a 16-lane splat
; of lane 0, so the add reduction equals 16 * %0. The expected output is a
; single `shl` by 4 (log2 of the 16 result lanes).
define i32 @constant_multiplied_at_0_two_pow16(i32 %0) {
336+
; CHECK-LABEL: @constant_multiplied_at_0_two_pow16(
337+
; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 4
338+
; CHECK-NEXT: ret i32 [[TMP2]]
339+
;
340+
%2 = insertelement <4 x i32> poison, i32 %0, i64 0
341+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison, <16 x i32> zeroinitializer
342+
%4 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %3)
343+
ret i32 %4
344+
}
345+
346+
347+
; Positive test: %0 is inserted at lane 1 and the shuffle mask selects
; lane 1 for every result lane, so the insertion index and the splat
; index agree. The add reduction equals 4 * %0 and the expected output is
; a single `shl` by 2.
define i32 @constant_multiplied_at_1(i32 %0) {
348+
; CHECK-LABEL: @constant_multiplied_at_1(
349+
; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 2
350+
; CHECK-NEXT: ret i32 [[TMP2]]
351+
;
352+
%2 = insertelement <4 x i32> poison, i32 %0, i64 1
353+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison,
354+
<4 x i32> <i32 1, i32 1, i32 1, i32 1>
355+
%4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
356+
ret i32 %4
357+
}
358+
359+
; Negative test: %0 is inserted at lane 1, but the zeroinitializer mask
; splats lane 0, which is still poison. Every lane of the reduced vector
; is therefore poison, so the expected output is `ret i32 poison` rather
; than a `shl` of %0.
define i32 @negative_constant_multiplied_at_1(i32 %0) {
360+
; CHECK-LABEL: @negative_constant_multiplied_at_1(
361+
; CHECK-NEXT: ret i32 poison
362+
;
363+
%2 = insertelement <4 x i32> poison, i32 %0, i64 1
364+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison, <4 x i32> zeroinitializer
365+
%4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
366+
ret i32 %4
367+
}
368+
369+
; Negative test: the splat has 6 lanes, which is not a power of two, so
; 6 * %0 cannot be expressed as a single left shift. The expected output
; keeps the insertelement, shufflevector and reduction call unchanged.
define i32 @negative_constant_multiplied_non_power_of_2(i32 %0) {
370+
; CHECK-LABEL: @negative_constant_multiplied_non_power_of_2(
371+
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x i32> poison, i32 [[TMP0:%.*]], i64 0
372+
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x i32> [[TMP2]], <4 x i32> poison, <6 x i32> zeroinitializer
373+
; CHECK-NEXT: [[TMP4:%.*]] = tail call i32 @llvm.vector.reduce.add.v6i32(<6 x i32> [[TMP3]])
374+
; CHECK-NEXT: ret i32 [[TMP4]]
375+
;
376+
%2 = insertelement <4 x i32> poison, i32 %0, i64 0
377+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison, <6 x i32> zeroinitializer
378+
%4 = tail call i32 @llvm.vector.reduce.add.v6i32(<6 x i32> %3)
379+
ret i32 %4
380+
}

0 commit comments

Comments
 (0)