Skip to content

Commit 7c4f188

Browse files
authored
[LV] Support multiplies by constants when forming scaled reductions. (#161092)
We can create partial reductions for multiplies with constants, if the constant is small enough to be extended from source to destination type w/o changing the value. This only handles constant on the right side of a multiply, relying on other passes to canonicalize the input. Alive2 Proofs: https://alive2.llvm.org/ce/z/iWRMr6 PR: #161092
1 parent 2165aa4 commit 7c4f188

File tree

5 files changed

+46
-17
lines changed

5 files changed

+46
-17
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7954,6 +7954,13 @@ bool VPRecipeBuilder::getScaledReductions(
79547954
auto CollectExtInfo = [this, &Exts, &ExtOpTypes,
79557955
&ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
79567956
for (const auto &[I, OpI] : enumerate(Ops)) {
7957+
auto *CI = dyn_cast<ConstantInt>(OpI);
7958+
if (I > 0 && CI &&
7959+
canConstantBeExtended(CI, ExtOpTypes[0], ExtKinds[0])) {
7960+
ExtOpTypes[I] = ExtOpTypes[0];
7961+
ExtKinds[I] = ExtKinds[0];
7962+
continue;
7963+
}
79577964
Value *ExtOp;
79587965
if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
79597966
return false;

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,6 +1753,16 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) {
17531753
}
17541754
#endif
17551755

1756+
bool llvm::canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
1757+
TTI::PartialReductionExtendKind ExtKind) {
1758+
APInt TruncatedVal = CI->getValue().trunc(NarrowType->getScalarSizeInBits());
1759+
unsigned WideSize = CI->getType()->getScalarSizeInBits();
1760+
APInt ExtendedVal = ExtKind == TTI::PR_SignExtend
1761+
? TruncatedVal.sext(WideSize)
1762+
: TruncatedVal.zext(WideSize);
1763+
return ExtendedVal == CI->getValue();
1764+
}
1765+
17561766
TargetTransformInfo::OperandValueInfo
17571767
VPCostContext::getOperandInfo(VPValue *V) const {
17581768
if (!V->isLiveIn())

llvm/lib/Transforms/Vectorize/VPlanHelpers.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,10 @@ class VPlanPrinter {
468468
};
469469
#endif
470470

471+
/// Check if a constant \p CI can be safely treated as having been extended
472+
/// from a narrower type with the given extension kind.
473+
bool canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
474+
TTI::PartialReductionExtendKind ExtKind);
471475
} // end namespace llvm
472476

473477
#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_H

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,14 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
340340
: Widen->getOperand(1));
341341
ExtAType = GetExtendKind(ExtAR);
342342
ExtBType = GetExtendKind(ExtBR);
343+
344+
if (!ExtBR && Widen->getOperand(1)->isLiveIn()) {
345+
auto *CI = cast<ConstantInt>(Widen->getOperand(1)->getLiveInIRValue());
346+
if (canConstantBeExtended(CI, InputTypeA, ExtAType)) {
347+
InputTypeB = InputTypeA;
348+
ExtBType = ExtAType;
349+
}
350+
}
343351
};
344352

345353
if (isa<VPWidenCastRecipe>(OpR)) {

llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,22 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) {
2020
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
2121
; CHECK: [[VECTOR_BODY]]:
2222
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
23-
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
23+
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
2424
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
2525
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
2626
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
2727
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63)
28-
; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
28+
; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
2929
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
30-
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
31-
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
30+
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
31+
; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
3232
; CHECK: [[MIDDLE_BLOCK]]:
33-
; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
33+
; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
3434
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
3535
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
3636
; CHECK: [[SCALAR_PH]]:
3737
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ]
38-
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
38+
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
3939
; CHECK-NEXT: br label %[[LOOP:.*]]
4040
; CHECK: [[LOOP]]:
4141
; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ]
@@ -48,7 +48,7 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) {
4848
; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]]
4949
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]]
5050
; CHECK: [[EXIT]]:
51-
; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
51+
; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ]
5252
; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]]
5353
;
5454
entry:
@@ -86,17 +86,17 @@ define i32 @red_zext_mul_by_255(ptr %start, ptr %end) {
8686
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
8787
; CHECK: [[VECTOR_BODY]]:
8888
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
89-
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
89+
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
9090
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
9191
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
9292
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
9393
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 255)
94-
; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
94+
; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
9595
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
9696
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
9797
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
9898
; CHECK: [[MIDDLE_BLOCK]]:
99-
; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
99+
; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
100100
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
101101
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
102102
; CHECK: [[SCALAR_PH]]:
@@ -218,22 +218,22 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) {
218218
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
219219
; CHECK: [[VECTOR_BODY]]:
220220
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
221-
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
221+
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
222222
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
223223
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
224224
; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
225225
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63)
226-
; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
226+
; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
227227
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
228-
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
229-
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
228+
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
229+
; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
230230
; CHECK: [[MIDDLE_BLOCK]]:
231-
; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
231+
; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
232232
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
233233
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
234234
; CHECK: [[SCALAR_PH]]:
235235
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ]
236-
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
236+
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
237237
; CHECK-NEXT: br label %[[LOOP:.*]]
238238
; CHECK: [[LOOP]]:
239239
; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ]
@@ -246,7 +246,7 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) {
246246
; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]]
247247
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP9:![0-9]+]]
248248
; CHECK: [[EXIT]]:
249-
; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
249+
; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ]
250250
; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]]
251251
;
252252
entry:

0 commit comments

Comments
 (0)