Skip to content

Commit acc8ac1

Browse files
committed
[LV] Simplify vp reduction creation code
1 parent 07fb694 commit acc8ac1

File tree

3 files changed

+26
-21
lines changed

3 files changed

+26
-21
lines changed

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,9 @@ LLVM_ABI Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
423423
RecurKind RdxKind);
424424
/// Overloaded function to generate vector-predication intrinsics for
425425
/// reduction.
426-
LLVM_ABI Value *createSimpleReduction(VectorBuilder &VB, Value *Src,
427-
RecurKind RdxKind);
426+
LLVM_ABI Value *createSimpleReduction(IRBuilderBase &B, Value *Src,
427+
RecurKind RdxKind, Value *Mask,
428+
Value *EVL);
428429

429430
/// Create a reduction of the given vector \p Src for a reduction of kind
430431
/// RecurKind::AnyOf. The start value of the reduction is \p InitVal.
@@ -442,8 +443,9 @@ LLVM_ABI Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind,
442443
Value *Src, Value *Start);
443444
/// Overloaded function to generate vector-predication intrinsics for ordered
444445
/// reduction.
445-
LLVM_ABI Value *createOrderedReduction(VectorBuilder &VB, RecurKind RdxKind,
446-
Value *Src, Value *Start);
446+
LLVM_ABI Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind,
447+
Value *Src, Value *Start, Value *Mask,
448+
Value *EVL);
447449

448450
/// Get the intersection (logical and) of all of the potential IR flags
449451
/// of each scalar operation (VL) that will be converted into a vector (I).

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,18 +1333,20 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
13331333
}
13341334
}
13351335

1336-
Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
1337-
RecurKind Kind) {
1336+
Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
1337+
RecurKind Kind, Value *Mask, Value *EVL) {
13381338
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
13391339
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
13401340
"AnyOf or FindLastIV reductions are not supported.");
13411341
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
1342-
auto *SrcTy = cast<VectorType>(Src->getType());
1343-
Type *SrcEltTy = SrcTy->getElementType();
1342+
auto VPID = VPIntrinsic::getForIntrinsic(Id);
1343+
assert(VPReductionIntrinsic::isVPReduction(VPID) &&
1344+
"No VPIntrinsic for this reduction");
1345+
auto *EltTy = cast<VectorType>(Src->getType())->getElementType();
13441346
Value *Iden =
1345-
getRecurrenceIdentity(Kind, SrcEltTy, VBuilder.getFastMathFlags());
1346-
Value *Ops[] = {Iden, Src};
1347-
return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
1347+
getRecurrenceIdentity(Kind, EltTy, Builder.getFastMathFlags());
1348+
Value *Ops[] = {Iden, Src, Mask, EVL};
1349+
return Builder.CreateIntrinsic(EltTy, VPID, Ops);
13481350
}
13491351

13501352
Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
@@ -1357,17 +1359,21 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
13571359
return B.CreateFAddReduce(Start, Src);
13581360
}
13591361

1360-
Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, RecurKind Kind,
1361-
Value *Src, Value *Start) {
1362+
Value *llvm::createOrderedReduction(IRBuilderBase &Builder, RecurKind Kind,
1363+
Value *Src, Value *Start, Value *Mask,
1364+
Value *EVL) {
13621365
assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
13631366
"Unexpected reduction kind");
13641367
assert(Src->getType()->isVectorTy() && "Expected a vector type");
13651368
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
13661369

13671370
Intrinsic::ID Id = getReductionIntrinsicID(RecurKind::FAdd);
1368-
auto *SrcTy = cast<VectorType>(Src->getType());
1369-
Value *Ops[] = {Start, Src};
1370-
return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
1371+
auto VPID = VPIntrinsic::getForIntrinsic(Id);
1372+
assert(VPReductionIntrinsic::isVPReduction(VPID) &&
1373+
"No VPIntrinsic for this reduction");
1374+
auto *EltTy = cast<VectorType>(Src->getType())->getElementType();
1375+
Value *Ops[] = {Start, Src, Mask, EVL};
1376+
return Builder.CreateIntrinsic(EltTy, VPID, Ops);
13711377
}
13721378

13731379
void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2524,21 +2524,18 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
25242524
Value *VecOp = State.get(getVecOp());
25252525
Value *EVL = State.get(getEVL(), VPLane(0));
25262526

2527-
VectorBuilder VBuilder(Builder);
2528-
VBuilder.setEVL(EVL);
25292527
Value *Mask;
25302528
// TODO: move the all-true mask generation into VectorBuilder.
25312529
if (VPValue *CondOp = getCondOp())
25322530
Mask = State.get(CondOp);
25332531
else
25342532
Mask = Builder.CreateVectorSplat(State.VF, Builder.getTrue());
2535-
VBuilder.setMask(Mask);
25362533

25372534
Value *NewRed;
25382535
if (isOrdered()) {
2539-
NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev);
2536+
NewRed = createOrderedReduction(Builder, Kind, VecOp, Prev, Mask, EVL);
25402537
} else {
2541-
NewRed = createSimpleReduction(VBuilder, VecOp, Kind);
2538+
NewRed = createSimpleReduction(Builder, VecOp, Kind, Mask, EVL);
25422539
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
25432540
NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev);
25442541
else

0 commit comments

Comments
 (0)