Skip to content

Commit 3538605

Browse files
[LLVM][ConstantFolding] Extend constantFoldVectorReduce to include scalable vectors.
1 parent 16f61ac commit 3538605

File tree

2 files changed

+64
-59
lines changed

2 files changed

+64
-59
lines changed

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2163,18 +2163,39 @@ Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
21632163
}
21642164

21652165
Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
2166-
FixedVectorType *VT = dyn_cast<FixedVectorType>(Op->getType());
2167-
if (!VT)
2168-
return nullptr;
2169-
2170-
// This isn't strictly necessary, but handle the special/common case of zero:
2171-
// all integer reductions of a zero input produce zero.
2172-
if (isa<ConstantAggregateZero>(Op))
2173-
return ConstantInt::get(VT->getElementType(), 0);
2166+
auto *OpVT = cast<VectorType>(Op->getType());
21742167

21752168
// This is the same as the underlying binops - poison propagates.
21762169
if (isa<PoisonValue>(Op) || Op->containsPoisonElement())
2177-
return PoisonValue::get(VT->getElementType());
2170+
return PoisonValue::get(OpVT->getElementType());
2171+
2172+
// Shortcut non-accumulating reductions.
2173+
if (Constant *SplatVal = Op->getSplatValue()) {
2174+
switch (IID) {
2175+
case Intrinsic::vector_reduce_and:
2176+
case Intrinsic::vector_reduce_or:
2177+
case Intrinsic::vector_reduce_smin:
2178+
case Intrinsic::vector_reduce_smax:
2179+
case Intrinsic::vector_reduce_umin:
2180+
case Intrinsic::vector_reduce_umax:
2181+
return SplatVal;
2182+
case Intrinsic::vector_reduce_add:
2183+
case Intrinsic::vector_reduce_mul:
2184+
if (SplatVal->isZeroValue())
2185+
return SplatVal;
2186+
break;
2187+
case Intrinsic::vector_reduce_xor:
2188+
if (SplatVal->isZeroValue())
2189+
return SplatVal;
2190+
if (OpVT->getElementCount().isKnownMultipleOf(2))
2191+
return Constant::getNullValue(OpVT->getElementType());
2192+
break;
2193+
}
2194+
}
2195+
2196+
FixedVectorType *VT = dyn_cast<FixedVectorType>(OpVT);
2197+
if (!VT)
2198+
return nullptr;
21782199

21792200
// TODO: Handle undef.
21802201
auto *EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(0U));

llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll

Lines changed: 34 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ define i32 @add_0() {
1212

1313
define i32 @add_0_scalable_vector() {
1414
; CHECK-LABEL: @add_0_scalable_vector(
15-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> zeroinitializer)
16-
; CHECK-NEXT: ret i32 [[X]]
15+
; CHECK-NEXT: ret i32 0
1716
;
1817
%x = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> zeroinitializer)
1918
ret i32 %x
@@ -89,8 +88,7 @@ define i32 @add_poison() {
8988

9089
define i32 @add_poison_scalable_vector() {
9190
; CHECK-LABEL: @add_poison_scalable_vector(
92-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> poison)
93-
; CHECK-NEXT: ret i32 [[X]]
91+
; CHECK-NEXT: ret i32 poison
9492
;
9593
%x = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> poison)
9694
ret i32 %x
@@ -123,8 +121,7 @@ define i32 @mul_0() {
123121

124122
define i32 @mul_0_scalable_vector() {
125123
; CHECK-LABEL: @mul_0_scalable_vector(
126-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> zeroinitializer)
127-
; CHECK-NEXT: ret i32 [[X]]
124+
; CHECK-NEXT: ret i32 0
128125
;
129126
%x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> zeroinitializer)
130127
ret i32 %x
@@ -200,8 +197,7 @@ define i32 @mul_poison() {
200197

201198
define i32 @mul_poison_scalable_vector() {
202199
; CHECK-LABEL: @mul_poison_scalable_vector(
203-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> poison)
204-
; CHECK-NEXT: ret i32 [[X]]
200+
; CHECK-NEXT: ret i32 poison
205201
;
206202
%x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> poison)
207203
ret i32 %x
@@ -225,8 +221,7 @@ define i32 @and_0() {
225221

226222
define i32 @and_0_scalable_vector() {
227223
; CHECK-LABEL: @and_0_scalable_vector(
228-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> zeroinitializer)
229-
; CHECK-NEXT: ret i32 [[X]]
224+
; CHECK-NEXT: ret i32 0
230225
;
231226
%x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> zeroinitializer)
232227
ret i32 %x
@@ -242,8 +237,7 @@ define i32 @and_1() {
242237

243238
define i32 @and_1_scalable_vector() {
244239
; CHECK-LABEL: @and_1_scalable_vector(
245-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> splat (i32 1))
246-
; CHECK-NEXT: ret i32 [[X]]
240+
; CHECK-NEXT: ret i32 1
247241
;
248242
%x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> splat (i32 1))
249243
ret i32 %x
@@ -302,8 +296,7 @@ define i32 @and_poison() {
302296

303297
define i32 @and_poison_scalable_vector() {
304298
; CHECK-LABEL: @and_poison_scalable_vector(
305-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> poison)
306-
; CHECK-NEXT: ret i32 [[X]]
299+
; CHECK-NEXT: ret i32 poison
307300
;
308301
%x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> poison)
309302
ret i32 %x
@@ -327,8 +320,7 @@ define i32 @or_0() {
327320

328321
define i32 @or_0_scalable_vector() {
329322
; CHECK-LABEL: @or_0_scalable_vector(
330-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> zeroinitializer)
331-
; CHECK-NEXT: ret i32 [[X]]
323+
; CHECK-NEXT: ret i32 0
332324
;
333325
%x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> zeroinitializer)
334326
ret i32 %x
@@ -344,8 +336,7 @@ define i32 @or_1() {
344336

345337
define i32 @or_1_scalable_vector() {
346338
; CHECK-LABEL: @or_1_scalable_vector(
347-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> splat (i32 1))
348-
; CHECK-NEXT: ret i32 [[X]]
339+
; CHECK-NEXT: ret i32 1
349340
;
350341
%x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> splat (i32 1))
351342
ret i32 %x
@@ -404,8 +395,7 @@ define i32 @or_poison() {
404395

405396
define i32 @or_poison_scalable_vector() {
406397
; CHECK-LABEL: @or_poison_scalable_vector(
407-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> poison)
408-
; CHECK-NEXT: ret i32 [[X]]
398+
; CHECK-NEXT: ret i32 poison
409399
;
410400
%x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> poison)
411401
ret i32 %x
@@ -429,8 +419,7 @@ define i32 @xor_0() {
429419

430420
define i32 @xor_0_scalable_vector() {
431421
; CHECK-LABEL: @xor_0_scalable_vector(
432-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> zeroinitializer)
433-
; CHECK-NEXT: ret i32 [[X]]
422+
; CHECK-NEXT: ret i32 0
434423
;
435424
%x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> zeroinitializer)
436425
ret i32 %x
@@ -446,13 +435,21 @@ define i32 @xor_1() {
446435

447436
define i32 @xor_1_scalable_vector() {
448437
; CHECK-LABEL: @xor_1_scalable_vector(
449-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> splat (i32 1))
450-
; CHECK-NEXT: ret i32 [[X]]
438+
; CHECK-NEXT: ret i32 0
451439
;
452440
%x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> splat(i32 1))
453441
ret i32 %x
454442
}
455443

444+
define i32 @xor_1_scalable_vector_lane_count_not_known_even() {
445+
; CHECK-LABEL: @xor_1_scalable_vector_lane_count_not_known_even(
446+
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv1i32(<vscale x 1 x i32> splat (i32 1))
447+
; CHECK-NEXT: ret i32 [[X]]
448+
;
449+
%x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 1 x i32> splat(i32 1))
450+
ret i32 %x
451+
}
452+
456453
define i32 @xor_inc() {
457454
; CHECK-LABEL: @xor_inc(
458455
; CHECK-NEXT: ret i32 10
@@ -506,8 +503,7 @@ define i32 @xor_poison() {
506503

507504
define i32 @xor_poison_scalable_vector() {
508505
; CHECK-LABEL: @xor_poison_scalable_vector(
509-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> poison)
510-
; CHECK-NEXT: ret i32 [[X]]
506+
; CHECK-NEXT: ret i32 poison
511507
;
512508
%x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> poison)
513509
ret i32 %x
@@ -531,8 +527,7 @@ define i32 @smin_0() {
531527

532528
define i32 @smin_0_scalable_vector() {
533529
; CHECK-LABEL: @smin_0_scalable_vector(
534-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
535-
; CHECK-NEXT: ret i32 [[X]]
530+
; CHECK-NEXT: ret i32 0
536531
;
537532
%x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
538533
ret i32 %x
@@ -548,8 +543,7 @@ define i32 @smin_1() {
548543

549544
define i32 @smin_1_scalable_vector() {
550545
; CHECK-LABEL: @smin_1_scalable_vector(
551-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
552-
; CHECK-NEXT: ret i32 [[X]]
546+
; CHECK-NEXT: ret i32 1
553547
;
554548
%x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> splat(i32 1))
555549
ret i32 %x
@@ -608,8 +602,7 @@ define i32 @smin_poison() {
608602

609603
define i32 @smin_poison_scalable_vector() {
610604
; CHECK-LABEL: @smin_poison_scalable_vector(
611-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> poison)
612-
; CHECK-NEXT: ret i32 [[X]]
605+
; CHECK-NEXT: ret i32 poison
613606
;
614607
%x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> poison)
615608
ret i32 %x
@@ -633,8 +626,7 @@ define i32 @smax_0() {
633626

634627
define i32 @smax_0_scalable_vector() {
635628
; CHECK-LABEL: @smax_0_scalable_vector(
636-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
637-
; CHECK-NEXT: ret i32 [[X]]
629+
; CHECK-NEXT: ret i32 0
638630
;
639631
%x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
640632
ret i32 %x
@@ -650,8 +642,7 @@ define i32 @smax_1() {
650642

651643
define i32 @smax_1_scalable_vector() {
652644
; CHECK-LABEL: @smax_1_scalable_vector(
653-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> splat (i32 1))
654-
; CHECK-NEXT: ret i32 [[X]]
645+
; CHECK-NEXT: ret i32 1
655646
;
656647
%x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> splat(i32 1))
657648
ret i32 %x
@@ -710,8 +701,7 @@ define i32 @smax_poison() {
710701

711702
define i32 @smax_poison_scalable_vector() {
712703
; CHECK-LABEL: @smax_poison_scalable_vector(
713-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> poison)
714-
; CHECK-NEXT: ret i32 [[X]]
704+
; CHECK-NEXT: ret i32 poison
715705
;
716706
%x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> poison)
717707
ret i32 %x
@@ -735,8 +725,7 @@ define i32 @umin_0() {
735725

736726
define i32 @umin_0_scalable_vector() {
737727
; CHECK-LABEL: @umin_0_scalable_vector(
738-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
739-
; CHECK-NEXT: ret i32 [[X]]
728+
; CHECK-NEXT: ret i32 0
740729
;
741730
%x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
742731
ret i32 %x
@@ -752,8 +741,7 @@ define i32 @umin_1() {
752741

753742
define i32 @umin_1_scalable_vector() {
754743
; CHECK-LABEL: @umin_1_scalable_vector(
755-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
756-
; CHECK-NEXT: ret i32 [[X]]
744+
; CHECK-NEXT: ret i32 1
757745
;
758746
%x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
759747
ret i32 %x
@@ -812,8 +800,7 @@ define i32 @umin_poison() {
812800

813801
define i32 @umin_poison_scalable_vector() {
814802
; CHECK-LABEL: @umin_poison_scalable_vector(
815-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> poison)
816-
; CHECK-NEXT: ret i32 [[X]]
803+
; CHECK-NEXT: ret i32 poison
817804
;
818805
%x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> poison)
819806
ret i32 %x
@@ -837,8 +824,7 @@ define i32 @umax_0() {
837824

838825
define i32 @umax_0_scalable_vector() {
839826
; CHECK-LABEL: @umax_0_scalable_vector(
840-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
841-
; CHECK-NEXT: ret i32 [[X]]
827+
; CHECK-NEXT: ret i32 0
842828
;
843829
%x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
844830
ret i32 %x
@@ -854,8 +840,7 @@ define i32 @umax_1() {
854840

855841
define i32 @umax_1_scalable_vector() {
856842
; CHECK-LABEL: @umax_1_scalable_vector(
857-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> splat (i32 1))
858-
; CHECK-NEXT: ret i32 [[X]]
843+
; CHECK-NEXT: ret i32 1
859844
;
860845
%x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> splat(i32 1))
861846
ret i32 %x
@@ -914,8 +899,7 @@ define i32 @umax_poison() {
914899

915900
define i32 @umax_poison_scalable_vector() {
916901
; CHECK-LABEL: @umax_poison_scalable_vector(
917-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> poison)
918-
; CHECK-NEXT: ret i32 [[X]]
902+
; CHECK-NEXT: ret i32 poison
919903
;
920904
%x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> poison)
921905
ret i32 %x

0 commit comments

Comments
 (0)