
Commit a69dcac

SamTebbs33 authored and aokblast committed
[LV] Bundle (partial) reductions with a mul of a constant (llvm#162503)
A reduction (including a partial reduction) with a multiply by a constant value can be bundled by first converting it from `reduce.add(mul(ext, const))` to `reduce.add(mul(ext, ext(const)))`, as long as it is safe to extend the constant.

This PR adds such bundling by first truncating the constant to the source type of the other extend, then extending it to the destination type of that extend. The truncate is necessary so that the operands of the two extends have the same type, and the call to canConstantBeExtended proves that extending after the truncate is safe. The truncate is removed by later optimisations.

This is a stacked PR; 1a and 1b can be merged in any order:
1a. llvm#147302
1b. llvm#163175
2. -> llvm#162503
1 parent f02ba7e · commit a69dcac
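
The legality condition in the message above is that the constant must survive the truncate-then-extend round trip. Below is a minimal standalone C++ sketch of that condition; it is a model of what `llvm::canConstantBeExtended` is used for here, not LLVM's actual implementation, and the function name, the `ExtendKind` enum, and the bit-width parameter are invented for illustration:

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical model: a constant C may be rewritten as ext(trunc(C)) only if
// truncating it to NarrowBits and re-extending it with the same extend kind
// reproduces the original value.
enum class ExtendKind { ZExt, SExt };

bool constantSurvivesTruncExtend(int64_t C, unsigned NarrowBits,
                                 ExtendKind Kind) {
  uint64_t Mask = NarrowBits >= 64 ? ~0ull : (1ull << NarrowBits) - 1;
  uint64_t Truncated = static_cast<uint64_t>(C) & Mask; // trunc to NarrowBits
  int64_t RoundTripped;
  if (Kind == ExtendKind::ZExt) {
    RoundTripped = static_cast<int64_t>(Truncated); // zero-extend back
  } else {
    uint64_t SignBit = 1ull << (NarrowBits - 1);    // sign-extend back
    RoundTripped = static_cast<int64_t>((Truncated ^ SignBit) - SignBit);
  }
  return RoundTripped == C;
}

int main() {
  std::cout << constantSurvivesTruncExtend(42, 8, ExtendKind::ZExt)   // 1
            << constantSurvivesTruncExtend(300, 8, ExtendKind::ZExt)  // 0
            << constantSurvivesTruncExtend(-1, 8, ExtendKind::SExt)   // 1
            << constantSurvivesTruncExtend(-1, 8, ExtendKind::ZExt)   // 0
            << '\n';
}
```

Under this model, an i32 constant 42 multiplied with a `zext i8 -> i32` operand can be rewritten as `zext (trunc 42 to i8) to i32`, while 300 cannot, because the truncate to i8 would lose bits.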

3 files changed, +676 -11 lines changed

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 52 additions & 11 deletions
@@ -3648,6 +3648,37 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
     Sub = VecOp->getDefiningRecipe();
     VecOp = Tmp;
   }
+
+  // If ValB is a constant and can be safely extended, truncate it to the same
+  // type as ExtA's operand, then extend it to the same type as ExtA. This
+  // creates two uniform extends that can more easily be matched by the rest of
+  // the bundling code. The ExtB reference, ValB and operand 1 of Mul are all
+  // replaced with the new extend of the constant.
+  auto ExtendAndReplaceConstantOp = [&Ctx](VPWidenCastRecipe *ExtA,
+                                           VPWidenCastRecipe *&ExtB,
+                                           VPValue *&ValB, VPWidenRecipe *Mul) {
+    if (!ExtA || ExtB || !ValB->isLiveIn())
+      return;
+    Type *NarrowTy = Ctx.Types.inferScalarType(ExtA->getOperand(0));
+    Instruction::CastOps ExtOpc = ExtA->getOpcode();
+    const APInt *Const;
+    if (!match(ValB, m_APInt(Const)) ||
+        !llvm::canConstantBeExtended(
+            Const, NarrowTy, TTI::getPartialReductionExtendKind(ExtOpc)))
+      return;
+    // The truncate ensures that the type of each extended operand is the
+    // same, and it's been proven that the constant can be extended from
+    // NarrowTy safely. Necessary since ExtA's extended operand would be
+    // e.g. an i8, while the const will likely be an i32. This will be
+    // elided by later optimisations.
+    VPBuilder Builder(Mul);
+    auto *Trunc =
+        Builder.createWidenCast(Instruction::CastOps::Trunc, ValB, NarrowTy);
+    Type *WideTy = Ctx.Types.inferScalarType(ExtA);
+    ValB = ExtB = Builder.createWidenCast(ExtOpc, Trunc, WideTy);
+    Mul->setOperand(1, ExtB);
+  };
+
   // Try to match reduce.add(mul(...)).
   if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
     auto *RecipeA =
@@ -3656,6 +3687,9 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
         dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
     auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
 
+    // Convert reduce.add(mul(ext, const)) to reduce.add(mul(ext, ext(const)))
+    ExtendAndReplaceConstantOp(RecipeA, RecipeB, B, Mul);
+
     // Match reduce.add/sub(mul(ext, ext)).
     if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
         match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
@@ -3665,7 +3699,6 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
           cast<VPWidenRecipe>(Sub), Red);
       return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
     }
-    // Match reduce.add(mul).
     // TODO: Add an expression type for this variant with a negated mul
     if (!Sub && IsMulAccValidAndClampRange(Mul, nullptr, nullptr, nullptr))
       return new VPExpressionRecipe(Mul, Red);
@@ -3674,18 +3707,26 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
   // variants.
   if (Sub)
     return nullptr;
-  // Match reduce.add(ext(mul(ext(A), ext(B)))).
-  // All extend recipes must have same opcode or A == B
-  // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
-  if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
-                                      m_ZExtOrSExt(m_VPValue()))))) {
+
+  // Match reduce.add(ext(mul(A, B))).
+  if (match(VecOp, m_ZExtOrSExt(m_Mul(m_VPValue(A), m_VPValue(B))))) {
     auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
     auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
-    auto *Ext0 =
-        cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
-    auto *Ext1 =
-        cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
-    if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
+    auto *Ext0 = dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+    auto *Ext1 = dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+
+    // reduce.add(ext(mul(ext, const)))
+    // -> reduce.add(ext(mul(ext, ext(const))))
+    ExtendAndReplaceConstantOp(Ext0, Ext1, B, Mul);
+
+    // reduce.add(ext(mul(ext(A), ext(B))))
+    // -> reduce.add(mul(wider_ext(A), wider_ext(B)))
+    // The inner extends must either have the same opcode as the outer extend or
+    // be the same, in which case the multiply can never result in a negative
+    // value and the outer extend can be folded away by doing wider
+    // extends for the operands of the mul.
+    if (Ext0 && Ext1 &&
+        (Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
         Ext0->getOpcode() == Ext1->getOpcode() &&
         IsMulAccValidAndClampRange(Mul, Ext0, Ext1, Ext) && Mul->hasOneUse()) {
       auto *NewExt0 = new VPWidenCastRecipe(
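
The comment in the final hunk argues that the outer extend can be folded into wider extends of the mul's operands when the inner extends share the outer extend's opcode, or when both operands are the same extend (the product is then a square and never negative). For i8 sources this can be checked exhaustively; the following standalone C++ sketch illustrates that reasoning and is not part of the patch:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  for (int a = -128; a <= 127; ++a) {
    for (int b = -128; b <= 127; ++b) {
      // Same opcode (sext inner, sext outer): the i16 product of two
      // sign-extended i8 values never overflows, so extending the narrow
      // product equals multiplying the wider-extended operands:
      // sext32(sext16(a) * sext16(b)) == sext32(a) * sext32(b).
      int16_t Narrow = static_cast<int16_t>(a) * static_cast<int16_t>(b);
      assert(static_cast<int32_t>(Narrow) ==
             static_cast<int32_t>(a) * static_cast<int32_t>(b));
    }
    // Same operand (Ext0 == Ext1): the product is a square, so it is
    // non-negative and even a zext of the narrow product matches the
    // result of multiplying wider sign-extended operands.
    int16_t Sq = static_cast<int16_t>(a) * static_cast<int16_t>(a);
    assert(static_cast<int32_t>(static_cast<uint16_t>(Sq)) ==
           static_cast<int32_t>(a) * static_cast<int32_t>(a));
  }
  return 0;
}
```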

llvm/test/Transforms/LoopVectorize/reduction-inloop.ll

Lines changed: 82 additions & 0 deletions
@@ -2800,6 +2800,88 @@ exit:
   ret i64 %r.0.lcssa
 }
 
+define i32 @reduction_expression_ext_mulacc_livein(ptr %a, i16 %c) {
+; CHECK-LABEL: define i32 @reduction_expression_ext_mulacc_livein(
+; CHECK-SAME: ptr [[A:%.*]], i16 [[C:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: br label %[[VECTOR_PH:.*]]
+; CHECK: [[VECTOR_PH]]:
+; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i16> poison, i16 [[C]], i64 0
+; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16> [[BROADCAST_SPLATINSERT]], <4 x i16> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK: [[VECTOR_BODY]]:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP0]], align 1
+; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i8> [[WIDE_LOAD]] to <4 x i16>
+; CHECK-NEXT: [[TMP2:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i16> [[TMP2]] to <4 x i32>
+; CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP3]])
+; CHECK-NEXT: [[TMP5]] = add i32 [[VEC_PHI]], [[TMP4]]
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
+; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
+; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP32:![0-9]+]]
+; CHECK: [[MIDDLE_BLOCK]]:
+; CHECK-NEXT: br label %[[FOR_EXIT:.*]]
+; CHECK: [[FOR_EXIT]]:
+; CHECK-NEXT: ret i32 [[TMP5]]
+;
+; CHECK-INTERLEAVED-LABEL: define i32 @reduction_expression_ext_mulacc_livein(
+; CHECK-INTERLEAVED-SAME: ptr [[A:%.*]], i16 [[C:%.*]]) {
+; CHECK-INTERLEAVED-NEXT: [[ENTRY:.*:]]
+; CHECK-INTERLEAVED-NEXT: br label %[[VECTOR_PH:.*]]
+; CHECK-INTERLEAVED: [[VECTOR_PH]]:
+; CHECK-INTERLEAVED-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i16> poison, i16 [[C]], i64 0
+; CHECK-INTERLEAVED-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16> [[BROADCAST_SPLATINSERT]], <4 x i16> poison, <4 x i32> zeroinitializer
+; CHECK-INTERLEAVED-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK-INTERLEAVED: [[VECTOR_BODY]]:
+; CHECK-INTERLEAVED-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP8:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT: [[VEC_PHI1:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP11:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-INTERLEAVED-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-INTERLEAVED-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[TMP0]], i32 4
+; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP0]], align 1
+; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD2:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
+; CHECK-INTERLEAVED-NEXT: [[TMP2:%.*]] = zext <4 x i8> [[WIDE_LOAD]] to <4 x i16>
+; CHECK-INTERLEAVED-NEXT: [[TMP3:%.*]] = zext <4 x i8> [[WIDE_LOAD2]] to <4 x i16>
+; CHECK-INTERLEAVED-NEXT: [[TMP4:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP2]]
+; CHECK-INTERLEAVED-NEXT: [[TMP5:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP3]]
+; CHECK-INTERLEAVED-NEXT: [[TMP6:%.*]] = zext <4 x i16> [[TMP4]] to <4 x i32>
+; CHECK-INTERLEAVED-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
+; CHECK-INTERLEAVED-NEXT: [[TMP8]] = add i32 [[VEC_PHI]], [[TMP7]]
+; CHECK-INTERLEAVED-NEXT: [[TMP9:%.*]] = zext <4 x i16> [[TMP5]] to <4 x i32>
+; CHECK-INTERLEAVED-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP9]])
+; CHECK-INTERLEAVED-NEXT: [[TMP11]] = add i32 [[VEC_PHI1]], [[TMP10]]
+; CHECK-INTERLEAVED-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
+; CHECK-INTERLEAVED-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
+; CHECK-INTERLEAVED-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP32:![0-9]+]]
+; CHECK-INTERLEAVED: [[MIDDLE_BLOCK]]:
+; CHECK-INTERLEAVED-NEXT: [[BIN_RDX:%.*]] = add i32 [[TMP11]], [[TMP8]]
+; CHECK-INTERLEAVED-NEXT: br label %[[FOR_EXIT:.*]]
+; CHECK-INTERLEAVED: [[FOR_EXIT]]:
+; CHECK-INTERLEAVED-NEXT: ret i32 [[BIN_RDX]]
+;
+entry:
+  br label %for.body
+
+for.body: ; preds = %for.body, %entry
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %gep.a = getelementptr i8, ptr %a, i64 %iv
+  %load.a = load i8, ptr %gep.a, align 1
+  %ext.a = zext i8 %load.a to i16
+  %mul = mul i16 %c, %ext.a
+  %mul.ext = zext i16 %mul to i32
+  %add = add i32 %mul.ext, %accum
+  %iv.next = add i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 1024
+  br i1 %exitcond.not, label %for.exit, label %for.body
+
+for.exit: ; preds = %for.body
+  ret i32 %add
+}
+
 declare float @llvm.fmuladd.f32(float, float, float)
 
 !6 = distinct !{!6, !7, !8}
