Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 49 additions & 11 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3642,6 +3642,37 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
Sub = VecOp->getDefiningRecipe();
VecOp = Tmp;
}

// If ValB is a constant and can be safely extended, truncate it to the same
// type as ExtA's operand, then extend it to the same type as ExtA. This
// creates two uniform extends that can more easily be matched by the rest of
// the bundling code. The ExtB reference, ValB and operand 1 of Mul are all
// replaced with the new extend of the constant.
// NOTE(review): only Mul's operand 1 is rewritten — assumes the constant
// operand sits on the RHS of the mul at both call sites; confirm.
auto ExtendAndReplaceConstantOp = [&Ctx](VPWidenCastRecipe *ExtA,
VPWidenCastRecipe *&ExtB,
VPValue *&ValB, VPWidenRecipe *Mul) {
// Bail out unless exactly one side is extended (ExtA) and the other
// (ValB) is a live-in; a live-in that is not a matchable APInt constant
// is rejected by the m_APInt match below.
if (!ExtA || ExtB || !ValB->isLiveIn())
return;
Type *NarrowTy = Ctx.Types.inferScalarType(ExtA->getOperand(0));
Instruction::CastOps ExtOpc = ExtA->getOpcode();
const APInt *Const;
// Only proceed if round-tripping the constant through NarrowTy with
// ExtA's extend kind (sign/zero) is lossless.
if (!match(ValB, m_APInt(Const)) ||
!llvm::canConstantBeExtended(
Const, NarrowTy, TTI::getPartialReductionExtendKind(ExtOpc)))
return;
// The truncate ensures that the type of each extended operand is the
// same, and it's been proven that the constant can be extended from
// NarrowTy safely. Necessary since ExtA's extended operand would be
// e.g. an i8, while the const will likely be an i32. This will be
// elided by later optimisations.
VPBuilder Builder(Mul);
auto *Trunc =
Builder.createWidenCast(Instruction::CastOps::Trunc, ValB, NarrowTy);
Type *WideTy = Ctx.Types.inferScalarType(ExtA);
// Rewire the caller's view (ExtB/ValB are by-reference) and the mul
// itself to use the new extend-of-truncated-constant.
ValB = ExtB = Builder.createWidenCast(ExtOpc, Trunc, WideTy);
Mul->setOperand(1, ExtB);
};

// Try to match reduce.add(mul(...)).
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
auto *RecipeA =
Expand All @@ -3650,6 +3681,9 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());

// Convert reduce.add(mul(ext, const)) to reduce.add(mul(ext, ext(const)))
ExtendAndReplaceConstantOp(RecipeA, RecipeB, B, Mul);

// Match reduce.add/sub(mul(ext, ext)).
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
Expand All @@ -3659,7 +3693,6 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
cast<VPWidenRecipe>(Sub), Red);
return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
}
// Match reduce.add(mul).
// TODO: Add an expression type for this variant with a negated mul
if (!Sub && IsMulAccValidAndClampRange(Mul, nullptr, nullptr, nullptr))
return new VPExpressionRecipe(Mul, Red);
Expand All @@ -3668,18 +3701,23 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
// variants.
if (Sub)
return nullptr;
// Match reduce.add(ext(mul(ext(A), ext(B)))).
// All extend recipes must have same opcode or A == B
// which can be transformed to reduce.add(zext(mul(sext(A), sext(B)))).
if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
m_ZExtOrSExt(m_VPValue()))))) {

// Match reduce.add(ext(mul(A, B))).
if (match(VecOp, m_ZExtOrSExt(m_Mul(m_VPValue(A), m_VPValue(B))))) {
auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
auto *Ext0 =
cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
auto *Ext1 =
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
auto *Ext0 = dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
auto *Ext1 = dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());

// reduce.add(ext(mul(ext, const)))
// -> reduce.add(ext(mul(ext, ext(const))))
ExtendAndReplaceConstantOp(Ext0, Ext1, B, Mul);

// Match reduce.add(ext(mul(ext(A), ext(B))))
// All extend recipes must have same opcode or A == B
// which can be transformed to reduce.add(zext(mul(sext(A), sext(B)))).
Comment on lines +3717 to +3718
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know you've just copied this from above, but this comment is not accurate.

The cases it tries to handle are:

reduce.add(zext(mul(zext(a), zext(b))))
-> reduce.add(mul(wider_zext(a), wider_zext(b)))

reduce.add(sext(mul(sext(a), sext(b))))
-> reduce.add(mul(wider_sext(a), wider_sext(b)))

and the other case (and reason for checking Ext0 == Ext1) is because that would mean the mul is non-negative which means that the final zero-extend can be folded away, i.e.

reduce.add(zext(mul(sext(a), sext(a)))) // result of mul is nneg
-> reduce.add(mul(wider_sext(a), wider_sext(a)))

if (Ext0 && Ext1 &&
(Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
Ext0->getOpcode() == Ext1->getOpcode() &&
IsMulAccValidAndClampRange(Mul, Ext0, Ext1, Ext) && Mul->hasOneUse()) {
auto *NewExt0 = new VPWidenCastRecipe(
Expand Down
82 changes: 82 additions & 0 deletions llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2800,6 +2800,88 @@ exit:
ret i64 %r.0.lcssa
}

; Reduction of the form reduce.add(zext(mul(%c, zext(load i8)))) where one
; mul operand (%c) is a loop-invariant live-in argument rather than a
; constant. Checks the vectorized output for both the plain and the
; interleaved (IC=2) configurations; the live-in is broadcast via
; insertelement/shufflevector in the vector preheader.
define i32 @reduction_expression_ext_mulacc_livein(ptr %a, i16 %c) {
; CHECK-LABEL: define i32 @reduction_expression_ext_mulacc_livein(
; CHECK-SAME: ptr [[A:%.*]], i16 [[C:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: br label %[[VECTOR_PH:.*]]
; CHECK: [[VECTOR_PH]]:
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i16> poison, i16 [[C]], i64 0
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16> [[BROADCAST_SPLATINSERT]], <4 x i16> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP0]], align 1
; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i8> [[WIDE_LOAD]] to <4 x i16>
; CHECK-NEXT: [[TMP2:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i16> [[TMP2]] to <4 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP3]])
; CHECK-NEXT: [[TMP5]] = add i32 [[VEC_PHI]], [[TMP4]]
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP32:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
; CHECK-NEXT: br label %[[FOR_EXIT:.*]]
; CHECK: [[FOR_EXIT]]:
; CHECK-NEXT: ret i32 [[TMP5]]
;
; CHECK-INTERLEAVED-LABEL: define i32 @reduction_expression_ext_mulacc_livein(
; CHECK-INTERLEAVED-SAME: ptr [[A:%.*]], i16 [[C:%.*]]) {
; CHECK-INTERLEAVED-NEXT: [[ENTRY:.*:]]
; CHECK-INTERLEAVED-NEXT: br label %[[VECTOR_PH:.*]]
; CHECK-INTERLEAVED: [[VECTOR_PH]]:
; CHECK-INTERLEAVED-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i16> poison, i16 [[C]], i64 0
; CHECK-INTERLEAVED-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16> [[BROADCAST_SPLATINSERT]], <4 x i16> poison, <4 x i32> zeroinitializer
; CHECK-INTERLEAVED-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK-INTERLEAVED: [[VECTOR_BODY]]:
; CHECK-INTERLEAVED-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
; CHECK-INTERLEAVED-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP8:%.*]], %[[VECTOR_BODY]] ]
; CHECK-INTERLEAVED-NEXT: [[VEC_PHI1:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP11:%.*]], %[[VECTOR_BODY]] ]
; CHECK-INTERLEAVED-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
; CHECK-INTERLEAVED-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[TMP0]], i32 4
; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP0]], align 1
; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD2:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
; CHECK-INTERLEAVED-NEXT: [[TMP2:%.*]] = zext <4 x i8> [[WIDE_LOAD]] to <4 x i16>
; CHECK-INTERLEAVED-NEXT: [[TMP3:%.*]] = zext <4 x i8> [[WIDE_LOAD2]] to <4 x i16>
; CHECK-INTERLEAVED-NEXT: [[TMP4:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP2]]
; CHECK-INTERLEAVED-NEXT: [[TMP5:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP3]]
; CHECK-INTERLEAVED-NEXT: [[TMP6:%.*]] = zext <4 x i16> [[TMP4]] to <4 x i32>
; CHECK-INTERLEAVED-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
; CHECK-INTERLEAVED-NEXT: [[TMP8]] = add i32 [[VEC_PHI]], [[TMP7]]
; CHECK-INTERLEAVED-NEXT: [[TMP9:%.*]] = zext <4 x i16> [[TMP5]] to <4 x i32>
; CHECK-INTERLEAVED-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP9]])
; CHECK-INTERLEAVED-NEXT: [[TMP11]] = add i32 [[VEC_PHI1]], [[TMP10]]
; CHECK-INTERLEAVED-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
; CHECK-INTERLEAVED-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-INTERLEAVED-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP32:![0-9]+]]
; CHECK-INTERLEAVED: [[MIDDLE_BLOCK]]:
; CHECK-INTERLEAVED-NEXT: [[BIN_RDX:%.*]] = add i32 [[TMP11]], [[TMP8]]
; CHECK-INTERLEAVED-NEXT: br label %[[FOR_EXIT:.*]]
; CHECK-INTERLEAVED: [[FOR_EXIT]]:
; CHECK-INTERLEAVED-NEXT: ret i32 [[BIN_RDX]]
;
entry:
br label %for.body

for.body: ; preds = %for.body, %entry
%iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
%accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
%gep.a = getelementptr i8, ptr %a, i64 %iv
%load.a = load i8, ptr %gep.a, align 1
%ext.a = zext i8 %load.a to i16
%mul = mul i16 %c, %ext.a
%mul.ext = zext i16 %mul to i32
%add = add i32 %mul.ext, %accum
%iv.next = add i64 %iv, 1
%exitcond.not = icmp eq i64 %iv.next, 1024
br i1 %exitcond.not, label %for.exit, label %for.body

for.exit: ; preds = %for.body
ret i32 %add
}

declare float @llvm.fmuladd.f32(float, float, float)

!6 = distinct !{!6, !7, !8}
Expand Down
Loading