Skip to content

Commit e9cc989

Browse files
committed
Address non power of 2 cases
1 parent fb492be commit e9cc989

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3779,15 +3779,22 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
37793779
static_cast<VectorType *>(Arg->getType())->getElementCount();
37803780
if (ReducedVectorElementCount.isFixed()) {
37813781
unsigned VectorSize = ReducedVectorElementCount.getFixedValue();
3782+
Type *SplatType = Splat->getType();
3783+
unsigned SplatTypeWidth = SplatType->getIntegerBitWidth();
3784+
Value *Res;
3785+
// Power of two is a special case. We can just use a left shift here.
37823786
if (isPowerOf2_32(VectorSize)) {
37833787
unsigned Pow2 = Log2_32(VectorSize);
3784-
Value *Res = Builder.CreateShl(
3785-
Splat,
3786-
Constant::getIntegerValue(
3787-
Splat->getType(),
3788-
APInt(Splat->getType()->getIntegerBitWidth(), Pow2)));
3788+
Res = Builder.CreateShl(
3789+
Splat, Constant::getIntegerValue(SplatType,
3790+
APInt(SplatTypeWidth, Pow2)));
37893791
return replaceInstUsesWith(CI, Res);
37903792
}
3793+
// Otherwise just multiply.
3794+
Res = Builder.CreateMul(
3795+
Splat, Constant::getIntegerValue(
3796+
SplatType, APInt(SplatTypeWidth, VectorSize)));
3797+
return replaceInstUsesWith(CI, Res);
37913798
}
37923799
}
37933800
}

llvm/test/Transforms/InstCombine/vector-reductions.ll

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -377,15 +377,24 @@ define i32 @negative_constant_multiplied_at_1(i32 %0) {
377377
ret i32 %4
378378
}
379379

380-
define i32 @negative_constant_multiplied_non_power_of_2(i32 %0) {
381-
; CHECK-LABEL: @negative_constant_multiplied_non_power_of_2(
382-
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x i32> poison, i32 [[TMP0:%.*]], i64 0
383-
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x i32> [[TMP2]], <4 x i32> poison, <6 x i32> zeroinitializer
384-
; CHECK-NEXT: [[TMP4:%.*]] = tail call i32 @llvm.vector.reduce.add.v6i32(<6 x i32> [[TMP3]])
385-
; CHECK-NEXT: ret i32 [[TMP4]]
380+
define i32 @constant_multiplied_non_power_of_2(i32 %0) {
381+
; CHECK-LABEL: @constant_multiplied_non_power_of_2(
382+
; CHECK-NEXT: [[TMP2:%.*]] = mul i32 [[TMP0:%.*]], 6
383+
; CHECK-NEXT: ret i32 [[TMP2]]
386384
;
387385
%2 = insertelement <4 x i32> poison, i32 %0, i64 0
388386
%3 = shufflevector <4 x i32> %2, <4 x i32> poison, <6 x i32> zeroinitializer
389387
%4 = tail call i32 @llvm.vector.reduce.add.v6i32(<6 x i32> %3)
390388
ret i32 %4
391389
}
390+
391+
define i64 @constant_multiplied_non_power_of_2_i64(i64 %0) {
392+
; CHECK-LABEL: @constant_multiplied_non_power_of_2_i64(
393+
; CHECK-NEXT: [[TMP2:%.*]] = mul i64 [[TMP0:%.*]], 6
394+
; CHECK-NEXT: ret i64 [[TMP2]]
395+
;
396+
%2 = insertelement <4 x i64> poison, i64 %0, i64 0
397+
%3 = shufflevector <4 x i64> %2, <4 x i64> poison, <6 x i32> zeroinitializer
398+
%4 = tail call i64 @llvm.vector.reduce.add.v6i64(<6 x i64> %3)
399+
ret i64 %4
400+
}

0 commit comments

Comments
 (0)