Skip to content

Commit a7b4dd4

Browse files
authored
[LV] Don't create partial reductions if factor doesn't match accumulator (#158603)
Check if the scale-factor of the accumulator is the same as the request ScaleFactor in tryToCreatePartialReductions. This prevents creating partial reductions if not all instructions in the reduction chain form partial reductions. e.g. because we do not form a partial reduction for the loop exit instruction. Currently code-gen works fine, because the scale factor of VPPartialReduction is not used during ::execute, but it means we compute incorrect cost/register pressure, because the partial reduction won't reduce to the specified scaling factor. PR: #158603
1 parent ea0e518 commit a7b4dd4

File tree

5 files changed

+35
-23
lines changed

5 files changed

+35
-23
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8183,8 +8183,11 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
81838183
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
81848184
return tryToWidenMemory(Instr, Operands, Range);
81858185

8186-
if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr))
8187-
return tryToCreatePartialReduction(Instr, Operands, ScaleFactor.value());
8186+
if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr)) {
8187+
if (auto PartialRed =
8188+
tryToCreatePartialReduction(Instr, Operands, ScaleFactor.value()))
8189+
return PartialRed;
8190+
}
81888191

81898192
if (!shouldWiden(Instr, Range))
81908193
return nullptr;
@@ -8218,6 +8221,10 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
82188221
isa<VPPartialReductionRecipe>(BinOpRecipe))
82198222
std::swap(BinOp, Accumulator);
82208223

8224+
if (ScaleFactor !=
8225+
vputils::getVFScaleFactor(Accumulator->getDefiningRecipe()))
8226+
return nullptr;
8227+
82218228
unsigned ReductionOpcode = Reduction->getOpcode();
82228229
if (ReductionOpcode == Instruction::Sub) {
82238230
auto *const Zero = ConstantInt::get(Reduction->getType(), 0);

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -395,20 +395,6 @@ bool VPDominatorTree::properlyDominates(const VPRecipeBase *A,
395395
return Base::properlyDominates(ParentA, ParentB);
396396
}
397397

398-
/// Get the VF scaling factor applied to the recipe's output, if the recipe has
399-
/// one.
400-
static unsigned getVFScaleFactor(VPValue *R) {
401-
if (auto *RR = dyn_cast<VPReductionPHIRecipe>(R))
402-
return RR->getVFScaleFactor();
403-
if (auto *RR = dyn_cast<VPPartialReductionRecipe>(R))
404-
return RR->getVFScaleFactor();
405-
assert(
406-
(!isa<VPInstruction>(R) || cast<VPInstruction>(R)->getOpcode() !=
407-
VPInstruction::ReductionStartVector) &&
408-
"getting scaling factor of reduction-start-vector not implemented yet");
409-
return 1;
410-
}
411-
412398
bool VPRegisterUsage::exceedsMaxNumRegs(const TargetTransformInfo &TTI,
413399
unsigned OverrideMaxNumRegs) const {
414400
return any_of(MaxLocalUsers, [&TTI, &OverrideMaxNumRegs](auto &LU) {
@@ -571,7 +557,8 @@ SmallVector<VPRegisterUsage, 8> llvm::calculateRegisterUsageForPlan(
571557
} else {
572558
// The output from scaled phis and scaled reductions actually has
573559
// fewer lanes than the VF.
574-
unsigned ScaleFactor = getVFScaleFactor(VPV);
560+
unsigned ScaleFactor =
561+
vputils::getVFScaleFactor(VPV->getDefiningRecipe());
575562
ElementCount VF = VFs[J].divideCoefficientBy(ScaleFactor);
576563
LLVM_DEBUG(if (VF != VFs[J]) {
577564
dbgs() << "LV(REG): Scaled down VF from " << VFs[J] << " to " << VF

llvm/lib/Transforms/Vectorize/VPlanUtils.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,20 @@ VPBasicBlock *vputils::getFirstLoopHeader(VPlan &Plan, VPDominatorTree &VPDT) {
141141
return I == DepthFirst.end() ? nullptr : cast<VPBasicBlock>(*I);
142142
}
143143

144+
unsigned vputils::getVFScaleFactor(VPRecipeBase *R) {
145+
if (!R)
146+
return 1;
147+
if (auto *RR = dyn_cast<VPReductionPHIRecipe>(R))
148+
return RR->getVFScaleFactor();
149+
if (auto *RR = dyn_cast<VPPartialReductionRecipe>(R))
150+
return RR->getVFScaleFactor();
151+
assert(
152+
(!isa<VPInstruction>(R) || cast<VPInstruction>(R)->getOpcode() !=
153+
VPInstruction::ReductionStartVector) &&
154+
"getting scaling factor of reduction-start-vector not implemented yet");
155+
return 1;
156+
}
157+
144158
std::optional<VPValue *>
145159
vputils::getRecipesForUncountableExit(VPlan &Plan,
146160
SmallVectorImpl<VPRecipeBase *> &Recipes,

llvm/lib/Transforms/Vectorize/VPlanUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ bool isUniformAcrossVFsAndUFs(VPValue *V);
102102
/// exist.
103103
VPBasicBlock *getFirstLoopHeader(VPlan &Plan, VPDominatorTree &VPDT);
104104

105+
/// Get the VF scaling factor applied to the recipe's output, if the recipe has
106+
/// one.
107+
unsigned getVFScaleFactor(VPRecipeBase *R);
108+
105109
/// Returns the VPValue representing the uncountable exit comparison used by
106110
/// AnyOf if the recipes it depends on can be traced back to live-ins and
107111
/// the addresses (in GEP/PtrAdd form) of any (non-masked) load used in

llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,8 +1381,8 @@ for.body: ; preds = %for.body.preheader,
13811381
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !loop !1
13821382
}
13831383

1384-
define i32 @red_extended_add_chain(ptr %start, ptr %end, i32 %offset) {
1385-
; CHECK-NEON-LABEL: define i32 @red_extended_add_chain(
1384+
define i32 @red_extended_add_incomplete_chain(ptr %start, ptr %end, i32 %offset) {
1385+
; CHECK-NEON-LABEL: define i32 @red_extended_add_incomplete_chain(
13861386
; CHECK-NEON-SAME: ptr [[START:%.*]], ptr [[END:%.*]], i32 [[OFFSET:%.*]]) #[[ATTR1:[0-9]+]] {
13871387
; CHECK-NEON-NEXT: entry:
13881388
; CHECK-NEON-NEXT: [[START2:%.*]] = ptrtoint ptr [[START]] to i64
@@ -1404,7 +1404,7 @@ define i32 @red_extended_add_chain(ptr %start, ptr %end, i32 %offset) {
14041404
; CHECK-NEON-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
14051405
; CHECK-NEON-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
14061406
; CHECK-NEON-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
1407-
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = call <16 x i32> @llvm.vector.partial.reduce.add.v16i32.v16i32(<16 x i32> [[VEC_PHI]], <16 x i32> [[TMP3]])
1407+
; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = add <16 x i32> [[VEC_PHI]], [[TMP3]]
14081408
; CHECK-NEON-NEXT: [[TMP4]] = add <16 x i32> [[PARTIAL_REDUCE]], [[BROADCAST_SPLAT]]
14091409
; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
14101410
; CHECK-NEON-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
@@ -1415,7 +1415,7 @@ define i32 @red_extended_add_chain(ptr %start, ptr %end, i32 %offset) {
14151415
; CHECK-NEON-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
14161416
; CHECK-NEON: scalar.ph:
14171417
;
1418-
; CHECK-SVE-LABEL: define i32 @red_extended_add_chain(
1418+
; CHECK-SVE-LABEL: define i32 @red_extended_add_incomplete_chain(
14191419
; CHECK-SVE-SAME: ptr [[START:%.*]], ptr [[END:%.*]], i32 [[OFFSET:%.*]]) #[[ATTR1:[0-9]+]] {
14201420
; CHECK-SVE-NEXT: entry:
14211421
; CHECK-SVE-NEXT: [[START2:%.*]] = ptrtoint ptr [[START]] to i64
@@ -1452,7 +1452,7 @@ define i32 @red_extended_add_chain(ptr %start, ptr %end, i32 %offset) {
14521452
; CHECK-SVE-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
14531453
; CHECK-SVE: scalar.ph:
14541454
;
1455-
; CHECK-SVE-MAXBW-LABEL: define i32 @red_extended_add_chain(
1455+
; CHECK-SVE-MAXBW-LABEL: define i32 @red_extended_add_incomplete_chain(
14561456
; CHECK-SVE-MAXBW-SAME: ptr [[START:%.*]], ptr [[END:%.*]], i32 [[OFFSET:%.*]]) #[[ATTR1:[0-9]+]] {
14571457
; CHECK-SVE-MAXBW-NEXT: entry:
14581458
; CHECK-SVE-MAXBW-NEXT: [[START2:%.*]] = ptrtoint ptr [[START]] to i64
@@ -1478,7 +1478,7 @@ define i32 @red_extended_add_chain(ptr %start, ptr %end, i32 %offset) {
14781478
; CHECK-SVE-MAXBW-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
14791479
; CHECK-SVE-MAXBW-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 8 x i8>, ptr [[NEXT_GEP]], align 1
14801480
; CHECK-SVE-MAXBW-NEXT: [[TMP7:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD]] to <vscale x 8 x i32>
1481-
; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE:%.*]] = call <vscale x 8 x i32> @llvm.vector.partial.reduce.add.nxv8i32.nxv8i32(<vscale x 8 x i32> [[VEC_PHI]], <vscale x 8 x i32> [[TMP7]])
1481+
; CHECK-SVE-MAXBW-NEXT: [[PARTIAL_REDUCE:%.*]] = add <vscale x 8 x i32> [[VEC_PHI]], [[TMP7]]
14821482
; CHECK-SVE-MAXBW-NEXT: [[TMP8]] = add <vscale x 8 x i32> [[PARTIAL_REDUCE]], [[BROADCAST_SPLAT]]
14831483
; CHECK-SVE-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
14841484
; CHECK-SVE-MAXBW-NEXT: [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]

0 commit comments

Comments
 (0)