Skip to content

Commit 96c6fd3

Browse files
[LLVM][ConstantFolding] Extend constantFoldVectorReduce to include scalable vectors. (#165437)
1 parent 8ea447b commit 96c6fd3

File tree

2 files changed

+86
-62
lines changed

2 files changed

+86
-62
lines changed

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 34 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2163,18 +2163,42 @@ Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
21632163
}
21642164

21652165
Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
2166-
FixedVectorType *VT = dyn_cast<FixedVectorType>(Op->getType());
2167-
if (!VT)
2168-
return nullptr;
2169-
2170-
// This isn't strictly necessary, but handle the special/common case of zero:
2171-
// all integer reductions of a zero input produce zero.
2172-
if (isa<ConstantAggregateZero>(Op))
2173-
return ConstantInt::get(VT->getElementType(), 0);
2166+
auto *OpVT = cast<VectorType>(Op->getType());
21742167

21752168
// This is the same as the underlying binops - poison propagates.
2176-
if (isa<PoisonValue>(Op) || Op->containsPoisonElement())
2177-
return PoisonValue::get(VT->getElementType());
2169+
if (Op->containsPoisonElement())
2170+
return PoisonValue::get(OpVT->getElementType());
2171+
2172+
// Shortcut non-accumulating reductions.
2173+
if (Constant *SplatVal = Op->getSplatValue()) {
2174+
switch (IID) {
2175+
case Intrinsic::vector_reduce_and:
2176+
case Intrinsic::vector_reduce_or:
2177+
case Intrinsic::vector_reduce_smin:
2178+
case Intrinsic::vector_reduce_smax:
2179+
case Intrinsic::vector_reduce_umin:
2180+
case Intrinsic::vector_reduce_umax:
2181+
return SplatVal;
2182+
case Intrinsic::vector_reduce_add:
2183+
if (SplatVal->isNullValue())
2184+
return SplatVal;
2185+
break;
2186+
case Intrinsic::vector_reduce_mul:
2187+
if (SplatVal->isNullValue() || SplatVal->isOneValue())
2188+
return SplatVal;
2189+
break;
2190+
case Intrinsic::vector_reduce_xor:
2191+
if (SplatVal->isNullValue())
2192+
return SplatVal;
2193+
if (OpVT->getElementCount().isKnownMultipleOf(2))
2194+
return Constant::getNullValue(OpVT->getElementType());
2195+
break;
2196+
}
2197+
}
2198+
2199+
FixedVectorType *VT = dyn_cast<FixedVectorType>(OpVT);
2200+
if (!VT)
2201+
return nullptr;
21782202

21792203
// TODO: Handle undef.
21802204
auto *EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(0U));

llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll

Lines changed: 52 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,7 @@ define i32 @add_0() {
1212

1313
define i32 @add_0_scalable_vector() {
1414
; CHECK-LABEL: @add_0_scalable_vector(
15-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> zeroinitializer)
16-
; CHECK-NEXT: ret i32 [[X]]
15+
; CHECK-NEXT: ret i32 0
1716
;
1817
%x = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> zeroinitializer)
1918
ret i32 %x
@@ -89,8 +88,7 @@ define i32 @add_poison() {
8988

9089
define i32 @add_poison_scalable_vector() {
9190
; CHECK-LABEL: @add_poison_scalable_vector(
92-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> poison)
93-
; CHECK-NEXT: ret i32 [[X]]
91+
; CHECK-NEXT: ret i32 poison
9492
;
9593
%x = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> poison)
9694
ret i32 %x
@@ -123,8 +121,7 @@ define i32 @mul_0() {
123121

124122
define i32 @mul_0_scalable_vector() {
125123
; CHECK-LABEL: @mul_0_scalable_vector(
126-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> zeroinitializer)
127-
; CHECK-NEXT: ret i32 [[X]]
124+
; CHECK-NEXT: ret i32 0
128125
;
129126
%x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> zeroinitializer)
130127
ret i32 %x
@@ -140,13 +137,29 @@ define i32 @mul_1() {
140137

141138
define i32 @mul_1_scalable_vector() {
142139
; CHECK-LABEL: @mul_1_scalable_vector(
143-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> splat (i32 1))
144-
; CHECK-NEXT: ret i32 [[X]]
140+
; CHECK-NEXT: ret i32 1
145141
;
146142
%x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> splat (i32 1))
147143
ret i32 %x
148144
}
149145

146+
define i32 @mul_2() {
147+
; CHECK-LABEL: @mul_2(
148+
; CHECK-NEXT: ret i32 256
149+
;
150+
%x = call i32 @llvm.vector.reduce.mul.v8i32(<8 x i32> <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>)
151+
ret i32 %x
152+
}
153+
154+
define i32 @mul_2_scalable_vector() {
155+
; CHECK-LABEL: @mul_2_scalable_vector(
156+
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> splat (i32 2))
157+
; CHECK-NEXT: ret i32 [[X]]
158+
;
159+
%x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> splat (i32 2))
160+
ret i32 %x
161+
}
162+
150163
define i32 @mul_inc() {
151164
; CHECK-LABEL: @mul_inc(
152165
; CHECK-NEXT: ret i32 40320
@@ -200,8 +213,7 @@ define i32 @mul_poison() {
200213

201214
define i32 @mul_poison_scalable_vector() {
202215
; CHECK-LABEL: @mul_poison_scalable_vector(
203-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> poison)
204-
; CHECK-NEXT: ret i32 [[X]]
216+
; CHECK-NEXT: ret i32 poison
205217
;
206218
%x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> poison)
207219
ret i32 %x
@@ -225,8 +237,7 @@ define i32 @and_0() {
225237

226238
define i32 @and_0_scalable_vector() {
227239
; CHECK-LABEL: @and_0_scalable_vector(
228-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> zeroinitializer)
229-
; CHECK-NEXT: ret i32 [[X]]
240+
; CHECK-NEXT: ret i32 0
230241
;
231242
%x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> zeroinitializer)
232243
ret i32 %x
@@ -242,8 +253,7 @@ define i32 @and_1() {
242253

243254
define i32 @and_1_scalable_vector() {
244255
; CHECK-LABEL: @and_1_scalable_vector(
245-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> splat (i32 1))
246-
; CHECK-NEXT: ret i32 [[X]]
256+
; CHECK-NEXT: ret i32 1
247257
;
248258
%x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> splat (i32 1))
249259
ret i32 %x
@@ -302,8 +312,7 @@ define i32 @and_poison() {
302312

303313
define i32 @and_poison_scalable_vector() {
304314
; CHECK-LABEL: @and_poison_scalable_vector(
305-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> poison)
306-
; CHECK-NEXT: ret i32 [[X]]
315+
; CHECK-NEXT: ret i32 poison
307316
;
308317
%x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> poison)
309318
ret i32 %x
@@ -327,8 +336,7 @@ define i32 @or_0() {
327336

328337
define i32 @or_0_scalable_vector() {
329338
; CHECK-LABEL: @or_0_scalable_vector(
330-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> zeroinitializer)
331-
; CHECK-NEXT: ret i32 [[X]]
339+
; CHECK-NEXT: ret i32 0
332340
;
333341
%x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> zeroinitializer)
334342
ret i32 %x
@@ -344,8 +352,7 @@ define i32 @or_1() {
344352

345353
define i32 @or_1_scalable_vector() {
346354
; CHECK-LABEL: @or_1_scalable_vector(
347-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> splat (i32 1))
348-
; CHECK-NEXT: ret i32 [[X]]
355+
; CHECK-NEXT: ret i32 1
349356
;
350357
%x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> splat (i32 1))
351358
ret i32 %x
@@ -404,8 +411,7 @@ define i32 @or_poison() {
404411

405412
define i32 @or_poison_scalable_vector() {
406413
; CHECK-LABEL: @or_poison_scalable_vector(
407-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> poison)
408-
; CHECK-NEXT: ret i32 [[X]]
414+
; CHECK-NEXT: ret i32 poison
409415
;
410416
%x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> poison)
411417
ret i32 %x
@@ -429,8 +435,7 @@ define i32 @xor_0() {
429435

430436
define i32 @xor_0_scalable_vector() {
431437
; CHECK-LABEL: @xor_0_scalable_vector(
432-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> zeroinitializer)
433-
; CHECK-NEXT: ret i32 [[X]]
438+
; CHECK-NEXT: ret i32 0
434439
;
435440
%x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> zeroinitializer)
436441
ret i32 %x
@@ -446,13 +451,21 @@ define i32 @xor_1() {
446451

447452
define i32 @xor_1_scalable_vector() {
448453
; CHECK-LABEL: @xor_1_scalable_vector(
449-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> splat (i32 1))
450-
; CHECK-NEXT: ret i32 [[X]]
454+
; CHECK-NEXT: ret i32 0
451455
;
452456
%x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> splat(i32 1))
453457
ret i32 %x
454458
}
455459

460+
define i32 @xor_1_scalable_vector_lane_count_not_known_even() {
461+
; CHECK-LABEL: @xor_1_scalable_vector_lane_count_not_known_even(
462+
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv1i32(<vscale x 1 x i32> splat (i32 1))
463+
; CHECK-NEXT: ret i32 [[X]]
464+
;
465+
%x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 1 x i32> splat(i32 1))
466+
ret i32 %x
467+
}
468+
456469
define i32 @xor_inc() {
457470
; CHECK-LABEL: @xor_inc(
458471
; CHECK-NEXT: ret i32 10
@@ -506,8 +519,7 @@ define i32 @xor_poison() {
506519

507520
define i32 @xor_poison_scalable_vector() {
508521
; CHECK-LABEL: @xor_poison_scalable_vector(
509-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> poison)
510-
; CHECK-NEXT: ret i32 [[X]]
522+
; CHECK-NEXT: ret i32 poison
511523
;
512524
%x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> poison)
513525
ret i32 %x
@@ -531,8 +543,7 @@ define i32 @smin_0() {
531543

532544
define i32 @smin_0_scalable_vector() {
533545
; CHECK-LABEL: @smin_0_scalable_vector(
534-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
535-
; CHECK-NEXT: ret i32 [[X]]
546+
; CHECK-NEXT: ret i32 0
536547
;
537548
%x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
538549
ret i32 %x
@@ -548,8 +559,7 @@ define i32 @smin_1() {
548559

549560
define i32 @smin_1_scalable_vector() {
550561
; CHECK-LABEL: @smin_1_scalable_vector(
551-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
552-
; CHECK-NEXT: ret i32 [[X]]
562+
; CHECK-NEXT: ret i32 1
553563
;
554564
%x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> splat(i32 1))
555565
ret i32 %x
@@ -608,8 +618,7 @@ define i32 @smin_poison() {
608618

609619
define i32 @smin_poison_scalable_vector() {
610620
; CHECK-LABEL: @smin_poison_scalable_vector(
611-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> poison)
612-
; CHECK-NEXT: ret i32 [[X]]
621+
; CHECK-NEXT: ret i32 poison
613622
;
614623
%x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> poison)
615624
ret i32 %x
@@ -633,8 +642,7 @@ define i32 @smax_0() {
633642

634643
define i32 @smax_0_scalable_vector() {
635644
; CHECK-LABEL: @smax_0_scalable_vector(
636-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
637-
; CHECK-NEXT: ret i32 [[X]]
645+
; CHECK-NEXT: ret i32 0
638646
;
639647
%x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
640648
ret i32 %x
@@ -650,8 +658,7 @@ define i32 @smax_1() {
650658

651659
define i32 @smax_1_scalable_vector() {
652660
; CHECK-LABEL: @smax_1_scalable_vector(
653-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> splat (i32 1))
654-
; CHECK-NEXT: ret i32 [[X]]
661+
; CHECK-NEXT: ret i32 1
655662
;
656663
%x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> splat(i32 1))
657664
ret i32 %x
@@ -710,8 +717,7 @@ define i32 @smax_poison() {
710717

711718
define i32 @smax_poison_scalable_vector() {
712719
; CHECK-LABEL: @smax_poison_scalable_vector(
713-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> poison)
714-
; CHECK-NEXT: ret i32 [[X]]
720+
; CHECK-NEXT: ret i32 poison
715721
;
716722
%x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> poison)
717723
ret i32 %x
@@ -735,8 +741,7 @@ define i32 @umin_0() {
735741

736742
define i32 @umin_0_scalable_vector() {
737743
; CHECK-LABEL: @umin_0_scalable_vector(
738-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
739-
; CHECK-NEXT: ret i32 [[X]]
744+
; CHECK-NEXT: ret i32 0
740745
;
741746
%x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
742747
ret i32 %x
@@ -752,8 +757,7 @@ define i32 @umin_1() {
752757

753758
define i32 @umin_1_scalable_vector() {
754759
; CHECK-LABEL: @umin_1_scalable_vector(
755-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
756-
; CHECK-NEXT: ret i32 [[X]]
760+
; CHECK-NEXT: ret i32 1
757761
;
758762
%x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
759763
ret i32 %x
@@ -812,8 +816,7 @@ define i32 @umin_poison() {
812816

813817
define i32 @umin_poison_scalable_vector() {
814818
; CHECK-LABEL: @umin_poison_scalable_vector(
815-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> poison)
816-
; CHECK-NEXT: ret i32 [[X]]
819+
; CHECK-NEXT: ret i32 poison
817820
;
818821
%x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> poison)
819822
ret i32 %x
@@ -837,8 +840,7 @@ define i32 @umax_0() {
837840

838841
define i32 @umax_0_scalable_vector() {
839842
; CHECK-LABEL: @umax_0_scalable_vector(
840-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
841-
; CHECK-NEXT: ret i32 [[X]]
843+
; CHECK-NEXT: ret i32 0
842844
;
843845
%x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
844846
ret i32 %x
@@ -854,8 +856,7 @@ define i32 @umax_1() {
854856

855857
define i32 @umax_1_scalable_vector() {
856858
; CHECK-LABEL: @umax_1_scalable_vector(
857-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> splat (i32 1))
858-
; CHECK-NEXT: ret i32 [[X]]
859+
; CHECK-NEXT: ret i32 1
859860
;
860861
%x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> splat(i32 1))
861862
ret i32 %x
@@ -914,8 +915,7 @@ define i32 @umax_poison() {
914915

915916
define i32 @umax_poison_scalable_vector() {
916917
; CHECK-LABEL: @umax_poison_scalable_vector(
917-
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> poison)
918-
; CHECK-NEXT: ret i32 [[X]]
918+
; CHECK-NEXT: ret i32 poison
919919
;
920920
%x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> poison)
921921
ret i32 %x

0 commit comments

Comments
 (0)