Skip to content

Commit d297987

Browse files
authored
[InstCombine] Transform vector.reduce.add and splat into multiplication (#161020)
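Fixes #160066. Whenever a vector has all elements equal to the same value (a splat created with `insertelement` and `shufflevector`) and we sum it with `vector.reduce.add`, the result is that value multiplied by the number of elements, so the reduction can be replaced by a multiplication.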
1 parent 71be13a commit d297987

File tree

2 files changed

+183
-0
lines changed

2 files changed

+183
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
#include "llvm/Support/KnownBits.h"
6565
#include "llvm/Support/KnownFPClass.h"
6666
#include "llvm/Support/MathExtras.h"
67+
#include "llvm/Support/TypeSize.h"
6768
#include "llvm/Support/raw_ostream.h"
6869
#include "llvm/Transforms/InstCombine/InstCombiner.h"
6970
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
@@ -3781,6 +3782,17 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
37813782
return replaceInstUsesWith(CI, Res);
37823783
}
37833784
}
3785+
3786+
// vector.reduce.add.vNiM(splat(%x)) -> mul(%x, N)
3787+
if (Value *Splat = getSplatValue(Arg)) {
3788+
ElementCount VecToReduceCount =
3789+
cast<VectorType>(Arg->getType())->getElementCount();
3790+
if (VecToReduceCount.isFixed()) {
3791+
unsigned VectorSize = VecToReduceCount.getFixedValue();
3792+
return BinaryOperator::CreateMul(
3793+
Splat, ConstantInt::get(Splat->getType(), VectorSize));
3794+
}
3795+
}
37843796
}
37853797
[[fallthrough]];
37863798
}

llvm/test/Transforms/InstCombine/vector-reductions.ll

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,174 @@ define i32 @diff_of_sums_type_mismatch2(<8 x i32> %v0, <4 x i32> %v1) {
308308
%r = sub i32 %r0, %r1
309309
ret i32 %r
310310
}
311+
312+
; Splat of %0 across 4 x i32 lanes fed to vector.reduce.add: the CHECK lines
; expect the reduction to become %0 << 2 (i.e. %0 * 4).
define i32 @constant_multiplied_4xi32(i32 %0) {
313+
; CHECK-LABEL: @constant_multiplied_4xi32(
314+
; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 2
315+
; CHECK-NEXT: ret i32 [[TMP2]]
316+
;
317+
%2 = insertelement <4 x i32> poison, i32 %0, i64 0
318+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison, <4 x i32> zeroinitializer
319+
%4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
320+
ret i32 %4
321+
}
322+
323+
; Non-power-of-two lane count: a 3 x i32 splat of %0 is reduced, and the CHECK
; lines expect a plain `mul` by 3 (no shift form is available).
define i32 @constant_multiplied_3xi32(i32 %0) {
324+
; CHECK-LABEL: @constant_multiplied_3xi32(
325+
; CHECK-NEXT: [[TMP2:%.*]] = mul i32 [[TMP0:%.*]], 3
326+
; CHECK-NEXT: ret i32 [[TMP2]]
327+
;
328+
%2 = insertelement <3 x i32> poison, i32 %0, i64 0
329+
%3 = shufflevector <3 x i32> %2, <3 x i32> poison, <3 x i32> zeroinitializer
330+
%4 = tail call i32 @llvm.vector.reduce.add.v3i32(<3 x i32> %3)
331+
ret i32 %4
332+
}
333+
334+
; Same pattern as the 4 x i32 case but with i64 elements: the CHECK lines
; expect %0 << 2 computed in i64.
define i64 @constant_multiplied_4xi64(i64 %0) {
335+
; CHECK-LABEL: @constant_multiplied_4xi64(
336+
; CHECK-NEXT: [[TMP2:%.*]] = shl i64 [[TMP0:%.*]], 2
337+
; CHECK-NEXT: ret i64 [[TMP2]]
338+
;
339+
%2 = insertelement <4 x i64> poison, i64 %0, i64 0
340+
%3 = shufflevector <4 x i64> %2, <4 x i64> poison, <4 x i32> zeroinitializer
341+
%4 = tail call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> %3)
342+
ret i64 %4
343+
}
344+
345+
; The shuffle widens a <4 x i32> source into an 8-lane splat. The CHECK lines
; expect %0 << 3, i.e. the multiplier is the lane count of the reduced vector
; (8), not of the shuffle's source vector (4).
define i32 @constant_multiplied_8xi32(i32 %0) {
346+
; CHECK-LABEL: @constant_multiplied_8xi32(
347+
; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 3
348+
; CHECK-NEXT: ret i32 [[TMP2]]
349+
;
350+
%2 = insertelement <4 x i32> poison, i32 %0, i64 0
351+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison, <8 x i32> zeroinitializer
352+
%4 = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %3)
353+
ret i32 %4
354+
}
355+
356+
357+
; The shuffle widens a <4 x i32> source into a 16-lane splat; the CHECK lines
; expect %0 << 4 (i.e. %0 * 16).
define i32 @constant_multiplied_16xi32(i32 %0) {
358+
; CHECK-LABEL: @constant_multiplied_16xi32(
359+
; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 4
360+
; CHECK-NEXT: ret i32 [[TMP2]]
361+
;
362+
%2 = insertelement <4 x i32> poison, i32 %0, i64 0
363+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison, <16 x i32> zeroinitializer
364+
%4 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %3)
365+
ret i32 %4
366+
}
367+
368+
369+
; The scalar is inserted at lane 1 and the shuffle mask broadcasts lane 1, so
; the result is still a splat of %0; the CHECK lines expect %0 << 2.
define i32 @constant_multiplied_4xi32_at_idx1(i32 %0) {
370+
; CHECK-LABEL: @constant_multiplied_4xi32_at_idx1(
371+
; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 2
372+
; CHECK-NEXT: ret i32 [[TMP2]]
373+
;
374+
%2 = insertelement <4 x i32> poison, i32 %0, i64 1
375+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison,
376+
<4 x i32> <i32 1, i32 1, i32 1, i32 1>
377+
%4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
378+
ret i32 %4
379+
}
380+
381+
; Negative test: %0 is inserted at lane 1, but the zero mask broadcasts lane 0,
; which is still poison. The CHECK lines expect the whole reduction to fold to
; poison rather than to a multiply of %0.
define i32 @negative_constant_multiplied_4xi32(i32 %0) {
382+
; CHECK-LABEL: @negative_constant_multiplied_4xi32(
383+
; CHECK-NEXT: ret i32 poison
384+
;
385+
%2 = insertelement <4 x i32> poison, i32 %0, i64 1
386+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison, <4 x i32> zeroinitializer
387+
%4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
388+
ret i32 %4
389+
}
390+
391+
; The shuffle widens a <4 x i32> source into a 6-lane splat (non-power-of-two
; count); the CHECK lines expect a `mul` by 6.
define i32 @constant_multiplied_6xi32(i32 %0) {
392+
; CHECK-LABEL: @constant_multiplied_6xi32(
393+
; CHECK-NEXT: [[TMP2:%.*]] = mul i32 [[TMP0:%.*]], 6
394+
; CHECK-NEXT: ret i32 [[TMP2]]
395+
;
396+
%2 = insertelement <4 x i32> poison, i32 %0, i64 0
397+
%3 = shufflevector <4 x i32> %2, <4 x i32> poison, <6 x i32> zeroinitializer
398+
%4 = tail call i32 @llvm.vector.reduce.add.v6i32(<6 x i32> %3)
399+
ret i32 %4
400+
}
401+
402+
; i64 variant of the 6-lane splat case; the CHECK lines expect a `mul` by 6
; computed in i64.
define i64 @constant_multiplied_6xi64(i64 %0) {
403+
; CHECK-LABEL: @constant_multiplied_6xi64(
404+
; CHECK-NEXT: [[TMP2:%.*]] = mul i64 [[TMP0:%.*]], 6
405+
; CHECK-NEXT: ret i64 [[TMP2]]
406+
;
407+
%2 = insertelement <4 x i64> poison, i64 %0, i64 0
408+
%3 = shufflevector <4 x i64> %2, <4 x i64> poison, <6 x i32> zeroinitializer
409+
%4 = tail call i64 @llvm.vector.reduce.add.v6i64(<6 x i64> %3)
410+
ret i64 %4
411+
}
412+
413+
; i1 elements: the CHECK lines expect a bitcast + ctpop + trunc sequence with
; the splat left in place, not a multiply.
; NOTE(review): presumably a separate i1-specific reduce.add fold runs before
; the splat -> mul transform; confirm against visitCallInst.
define i1 @constant_multiplied_8xi1(i1 %0) {
414+
; CHECK-LABEL: @constant_multiplied_8xi1(
415+
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x i1> poison, i1 [[TMP0:%.*]], i64 0
416+
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> poison, <8 x i32> zeroinitializer
417+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i1> [[TMP3]] to i8
418+
; CHECK-NEXT: [[TMP5:%.*]] = call range(i8 0, 9) i8 @llvm.ctpop.i8(i8 [[TMP4]])
419+
; CHECK-NEXT: [[TMP6:%.*]] = trunc i8 [[TMP5]] to i1
420+
; CHECK-NEXT: ret i1 [[TMP6]]
421+
;
422+
%2 = insertelement <8 x i1> poison, i1 %0, i32 0
423+
%3 = shufflevector <8 x i1> %2, <8 x i1> poison, <8 x i32> zeroinitializer
424+
%4 = tail call i1 @llvm.vector.reduce.add.v8i1(<8 x i1> %3)
425+
ret i1 %4
426+
}
427+
428+
; Narrow element type: the lane count 4 wraps to 0 in i2 (4 mod 4), so the
; CHECK lines expect the reduction to fold to the constant 0.
define i2 @constant_multiplied_4xi2(i2 %0) {
429+
; CHECK-LABEL: @constant_multiplied_4xi2(
430+
; CHECK-NEXT: ret i2 0
431+
;
432+
%2 = insertelement <4 x i2> poison, i2 %0, i32 0
433+
%3 = shufflevector <4 x i2> %2, <4 x i2> poison, <4 x i32> zeroinitializer
434+
%4 = tail call i2 @llvm.vector.reduce.add.v4i2(<4 x i2> %3)
435+
ret i2 %4
436+
}
437+
438+
; Narrow element type: the lane count 5 wraps to 1 in i2 (5 mod 4), so the
; CHECK lines expect the reduction to fold to %0 itself.
define i2 @constant_multiplied_5xi2(i2 %0) {
439+
; CHECK-LABEL: @constant_multiplied_5xi2(
440+
; CHECK-NEXT: ret i2 [[TMP0:%.*]]
441+
;
442+
%2 = insertelement <5 x i2> poison, i2 %0, i64 0
443+
%3 = shufflevector <5 x i2> %2, <5 x i2> poison, <5 x i32> zeroinitializer
444+
%4 = tail call i2 @llvm.vector.reduce.add.v5i2(<5 x i2> %3)
445+
ret i2 %4
446+
}
447+
448+
; Narrow element type: the lane count 6 wraps to 2 in i2 (6 mod 4), so the
; CHECK lines expect %0 << 1 (i.e. %0 * 2).
define i2 @constant_multiplied_6xi2(i2 %0) {
449+
; CHECK-LABEL: @constant_multiplied_6xi2(
450+
; CHECK-NEXT: [[TMP2:%.*]] = shl i2 [[TMP0:%.*]], 1
451+
; CHECK-NEXT: ret i2 [[TMP2]]
452+
;
453+
%2 = insertelement <6 x i2> poison, i2 %0, i64 0
454+
%3 = shufflevector <6 x i2> %2, <6 x i2> poison, <6 x i32> zeroinitializer
455+
%4 = tail call i2 @llvm.vector.reduce.add.v6i2(<6 x i2> %3)
456+
ret i2 %4
457+
}
458+
459+
; Narrow element type: the lane count 7 wraps to 3 in i2 (7 mod 4), which is
; -1 as a signed i2, so the CHECK lines expect a negation (0 - %0).
define i2 @constant_multiplied_7xi2(i2 %0) {
460+
; CHECK-LABEL: @constant_multiplied_7xi2(
461+
; CHECK-NEXT: [[TMP2:%.*]] = sub i2 0, [[TMP0:%.*]]
462+
; CHECK-NEXT: ret i2 [[TMP2]]
463+
;
464+
%2 = insertelement <7 x i2> poison, i2 %0, i64 0
465+
%3 = shufflevector <7 x i2> %2, <7 x i2> poison, <7 x i32> zeroinitializer
466+
%4 = tail call i2 @llvm.vector.reduce.add.v7i2(<7 x i2> %3)
467+
ret i2 %4
468+
}
469+
470+
; Negative test: the splat is a scalable vector (<vscale x 4 x i32>), whose
; lane count is not a compile-time constant. The CHECK lines expect the
; insertelement / shufflevector / reduce.add sequence to be left unchanged.
define i32 @negative_scalable_vector(i32 %0) {
471+
; CHECK-LABEL: @negative_scalable_vector(
472+
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[TMP0:%.*]], i64 0
473+
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <vscale x 4 x i32> [[TMP2]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
474+
; CHECK-NEXT: [[TMP4:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP3]])
475+
; CHECK-NEXT: ret i32 [[TMP4]]
476+
;
477+
%2 = insertelement <vscale x 4 x i32> poison, i32 %0, i64 0
478+
%3 = shufflevector <vscale x 4 x i32> %2, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
479+
%4 = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %3)
480+
ret i32 %4
481+
}

0 commit comments

Comments
 (0)