Skip to content

Commit 7d90e64

Browse files
committed
Extend vector.reduce.add and splat transform to scalable vectors
1 parent d297987 commit 7d90e64

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3785,13 +3785,19 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
 
       // vector.reduce.add.vNiM(splat(%x)) -> mul(%x, N)
       if (Value *Splat = getSplatValue(Arg)) {
-        ElementCount VecToReduceCount =
-            cast<VectorType>(Arg->getType())->getElementCount();
+        VectorType *VecToReduceTy = cast<VectorType>(Arg->getType());
+        ElementCount VecToReduceCount = VecToReduceTy->getElementCount();
+        Value *RHS;
         if (VecToReduceCount.isFixed()) {
           unsigned VectorSize = VecToReduceCount.getFixedValue();
-          return BinaryOperator::CreateMul(
-              Splat, ConstantInt::get(Splat->getType(), VectorSize));
+          RHS = ConstantInt::get(Splat->getType(), VectorSize);
         }
+
+        RHS = Builder.CreateElementCount(Type::getInt64Ty(II->getContext()),
+                                         VecToReduceCount);
+        if (Splat->getType() != RHS->getType())
+          RHS = Builder.CreateZExtOrTrunc(RHS, Splat->getType());
+        return BinaryOperator::CreateMul(Splat, RHS);
       }
     }
     [[fallthrough]];

llvm/test/Transforms/InstCombine/vector-reductions.ll

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,10 @@ define i2 @constant_multiplied_7xi2(i2 %0) {
 
 define i32 @negative_scalable_vector(i32 %0) {
 ; CHECK-LABEL: @negative_scalable_vector(
-; CHECK-NEXT: [[TMP2:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[TMP0:%.*]], i64 0
-; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <vscale x 4 x i32> [[TMP2]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT: [[TMP4:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP3]])
+; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[DOTTR:%.*]] = trunc i64 [[TMP2]] to i32
+; CHECK-NEXT: [[TMP3:%.*]] = shl i32 [[DOTTR]], 2
+; CHECK-NEXT: [[TMP4:%.*]] = mul i32 [[TMP0:%.*]], [[TMP3]]
 ; CHECK-NEXT: ret i32 [[TMP4]]
 ;
   %2 = insertelement <vscale x 4 x i32> poison, i32 %0, i64 0

0 commit comments

Comments
 (0)