Skip to content

Commit 3396b46

Browse files
authored
[LV] Allow partial reductions with an extended bin op (#165536)
A pattern of the form reduce.add(ext(mul)) is valid for a partial reduction as long as the mul and its operands fulfill the requirements of a normal partial reduction. The mul's inner extend operands will be folded into the wider outer extend, and we already have one-use checks in place to make sure the mul and its operands can be modified safely. 1. -> #165536 2. #165543
1 parent 2cf550a commit 3396b46

File tree

5 files changed

+702
-83
lines changed

5 files changed

+702
-83
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8061,6 +8061,19 @@ bool VPRecipeBuilder::getScaledReductions(
80618061
if (Op == PHI)
80628062
std::swap(Op, PhiOp);
80638063

8064+
using namespace llvm::PatternMatch;
8065+
// If Op is an extend, then it's still a valid partial reduction if the
8066+
// extended mul fulfills the other requirements.
8067+
// For example, reduce.add(ext(mul(ext(A), ext(B)))) is still a valid partial
8068+
// reduction since the inner extends will be widened. We already have oneUse
8069+
// checks on the inner extends so widening them is safe.
8070+
std::optional<TTI::PartialReductionExtendKind> OuterExtKind = std::nullopt;
8071+
if (match(Op, m_ZExtOrSExt(m_Mul(m_Value(), m_Value())))) {
8072+
auto *Cast = cast<CastInst>(Op);
8073+
OuterExtKind = TTI::getPartialReductionExtendKind(Cast->getOpcode());
8074+
Op = Cast->getOperand(0);
8075+
}
8076+
80648077
// Try and get a scaled reduction from the first non-phi operand.
80658078
// If one is found, we use the discovered reduction instruction in
80668079
// place of the accumulator for costing.
@@ -8077,8 +8090,6 @@ bool VPRecipeBuilder::getScaledReductions(
80778090
if (PhiOp != PHI)
80788091
return false;
80798092

8080-
using namespace llvm::PatternMatch;
8081-
80828093
// If the update is a binary operator, check both of its operands to see if
80838094
// they are extends. Otherwise, see if the update comes directly from an
80848095
// extend.
@@ -8088,7 +8099,7 @@ bool VPRecipeBuilder::getScaledReductions(
80888099
Type *ExtOpTypes[2] = {nullptr};
80898100
TTI::PartialReductionExtendKind ExtKinds[2] = {TTI::PR_None};
80908101

8091-
auto CollectExtInfo = [this, &Exts, &ExtOpTypes,
8102+
auto CollectExtInfo = [this, OuterExtKind, &Exts, &ExtOpTypes,
80928103
&ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
80938104
for (const auto &[I, OpI] : enumerate(Ops)) {
80948105
const APInt *C;
@@ -8109,6 +8120,10 @@ bool VPRecipeBuilder::getScaledReductions(
81098120

81108121
ExtOpTypes[I] = ExtOp->getType();
81118122
ExtKinds[I] = TTI::getPartialReductionExtendKind(Exts[I]);
8123+
// The outer extend kind must be the same as the inner extends, so that
8124+
// they can be folded together.
8125+
if (OuterExtKind.has_value() && OuterExtKind.value() != ExtKinds[I])
8126+
return false;
81128127
}
81138128
return true;
81148129
};

llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll

Lines changed: 0 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -467,83 +467,3 @@ loop:
467467
exit:
468468
ret i32 %red.next
469469
}
470-
471-
define i64 @partial_reduction_mul_two_users(i64 %n, ptr %a, i16 %b, i32 %c) {
472-
; CHECK-LABEL: define i64 @partial_reduction_mul_two_users(
473-
; CHECK-SAME: i64 [[N:%.*]], ptr [[A:%.*]], i16 [[B:%.*]], i32 [[C:%.*]]) #[[ATTR0]] {
474-
; CHECK-NEXT: [[ENTRY:.*]]:
475-
; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[N]], 1
476-
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 8
477-
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
478-
; CHECK: [[VECTOR_PH]]:
479-
; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 8
480-
; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
481-
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <8 x i16> poison, i16 [[B]], i64 0
482-
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <8 x i16> [[BROADCAST_SPLATINSERT]], <8 x i16> poison, <8 x i32> zeroinitializer
483-
; CHECK-NEXT: [[TMP1:%.*]] = sext <8 x i16> [[BROADCAST_SPLAT]] to <8 x i32>
484-
; CHECK-NEXT: [[TMP2:%.*]] = mul <8 x i32> [[TMP1]], [[TMP1]]
485-
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
486-
; CHECK: [[VECTOR_BODY]]:
487-
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
488-
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i64> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
489-
; CHECK-NEXT: [[TMP4:%.*]] = load i16, ptr [[A]], align 2
490-
; CHECK-NEXT: [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <8 x i16> poison, i16 [[TMP4]], i64 0
491-
; CHECK-NEXT: [[BROADCAST_SPLAT2:%.*]] = shufflevector <8 x i16> [[BROADCAST_SPLATINSERT1]], <8 x i16> poison, <8 x i32> zeroinitializer
492-
; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i32> [[TMP2]] to <8 x i64>
493-
; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i64> @llvm.vector.partial.reduce.add.v4i64.v8i64(<4 x i64> [[VEC_PHI]], <8 x i64> [[TMP3]])
494-
; CHECK-NEXT: [[TMP5:%.*]] = sext <8 x i16> [[BROADCAST_SPLAT2]] to <8 x i32>
495-
; CHECK-NEXT: [[TMP6:%.*]] = sext <8 x i32> [[TMP5]] to <8 x i64>
496-
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
497-
; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
498-
; CHECK-NEXT: br i1 [[TMP7]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP18:![0-9]+]]
499-
; CHECK: [[MIDDLE_BLOCK]]:
500-
; CHECK-NEXT: [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[PARTIAL_REDUCE]])
501-
; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <8 x i64> [[TMP6]], i32 7
502-
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP0]], [[N_VEC]]
503-
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
504-
; CHECK: [[SCALAR_PH]]:
505-
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
506-
; CHECK-NEXT: [[SCALAR_RECUR_INIT:%.*]] = phi i64 [ [[VECTOR_RECUR_EXTRACT]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
507-
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i64 [ [[TMP8]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
508-
; CHECK-NEXT: br label %[[LOOP:.*]]
509-
; CHECK: [[LOOP]]:
510-
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
511-
; CHECK-NEXT: [[RES1:%.*]] = phi i64 [ [[SCALAR_RECUR_INIT]], %[[SCALAR_PH]] ], [ [[LOAD_EXT_EXT:%.*]], %[[LOOP]] ]
512-
; CHECK-NEXT: [[RES2:%.*]] = phi i64 [ [[BC_MERGE_RDX]], %[[SCALAR_PH]] ], [ [[ADD:%.*]], %[[LOOP]] ]
513-
; CHECK-NEXT: [[LOAD:%.*]] = load i16, ptr [[A]], align 2
514-
; CHECK-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1
515-
; CHECK-NEXT: [[CONV:%.*]] = sext i16 [[B]] to i32
516-
; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[CONV]], [[CONV]]
517-
; CHECK-NEXT: [[MUL_EXT:%.*]] = zext i32 [[MUL]] to i64
518-
; CHECK-NEXT: [[ADD]] = add i64 [[RES2]], [[MUL_EXT]]
519-
; CHECK-NEXT: [[OR:%.*]] = or i32 [[MUL]], [[C]]
520-
; CHECK-NEXT: [[LOAD_EXT:%.*]] = sext i16 [[LOAD]] to i32
521-
; CHECK-NEXT: [[LOAD_EXT_EXT]] = sext i32 [[LOAD_EXT]] to i64
522-
; CHECK-NEXT: [[EXITCOND740_NOT:%.*]] = icmp eq i64 [[IV]], [[N]]
523-
; CHECK-NEXT: br i1 [[EXITCOND740_NOT]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP19:![0-9]+]]
524-
; CHECK: [[EXIT]]:
525-
; CHECK-NEXT: [[ADD_LCSSA:%.*]] = phi i64 [ [[ADD]], %[[LOOP]] ], [ [[TMP8]], %[[MIDDLE_BLOCK]] ]
526-
; CHECK-NEXT: ret i64 [[ADD_LCSSA]]
527-
;
528-
entry:
529-
br label %loop
530-
531-
loop:
532-
%iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
533-
%res1 = phi i64 [ 0, %entry ], [ %load.ext.ext, %loop ]
534-
%res2 = phi i64 [ 0, %entry ], [ %add, %loop ]
535-
%load = load i16, ptr %a, align 2
536-
%iv.next = add i64 %iv, 1
537-
%conv = sext i16 %b to i32
538-
%mul = mul i32 %conv, %conv
539-
%mul.ext = zext i32 %mul to i64
540-
%add = add i64 %res2, %mul.ext
541-
%second_use = or i32 %mul, %c ; this value is otherwise unused, but that's sufficient for the test
542-
%load.ext = sext i16 %load to i32
543-
%load.ext.ext = sext i32 %load.ext to i64
544-
%exitcond740.not = icmp eq i64 %iv, %n
545-
br i1 %exitcond740.not, label %exit, label %loop
546-
547-
exit:
548-
ret i64 %add
549-
}

0 commit comments

Comments
 (0)