Skip to content

Commit 44bd087

Browse files
committed
Pass both input types to target
1 parent 6540bc6 commit 44bd087

File tree

6 files changed

+39
-31
lines changed

6 files changed

+39
-31
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,10 +1291,12 @@ class TargetTransformInfo {
12911291
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
12921292
/// takes an accumulator and a binary operation operand that itself is fed by
12931293
/// two extends. An example of an operation that uses a partial reduction is a
1294-
/// dot product, which reduces a vector to another of 4 times fewer elements.
1294+
/// dot product, which reduces two vectors to another of 4 times fewer and 4
1295+
/// times larger elements.
12951296
InstructionCost
1296-
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
1297-
ElementCount VF, PartialReductionExtendKind OpAExtend,
1297+
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
1298+
Type *AccumType, ElementCount VF,
1299+
PartialReductionExtendKind OpAExtend,
12981300
PartialReductionExtendKind OpBExtend,
12991301
std::optional<unsigned> BinOp = std::nullopt) const;
13001302

@@ -2130,10 +2132,12 @@ class TargetTransformInfo::Concept {
21302132
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
21312133
/// takes an accumulator and a binary operation operand that itself is fed by
21322134
/// two extends. An example of an operation that uses a partial reduction is a
2133-
/// dot product, which reduces a vector to another of 4 times fewer elements.
2135+
/// dot product, which reduces two vectors to another of 4 times fewer and 4
2136+
/// times larger elements.
21342137
virtual InstructionCost
2135-
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
2136-
ElementCount VF, PartialReductionExtendKind OpAExtend,
2138+
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
2139+
Type *AccumType, ElementCount VF,
2140+
PartialReductionExtendKind OpAExtend,
21372141
PartialReductionExtendKind OpBExtend,
21382142
std::optional<unsigned> BinOp) const = 0;
21392143

@@ -2817,12 +2821,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
28172821
}
28182822

28192823
InstructionCost getPartialReductionCost(
2820-
unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
2821-
PartialReductionExtendKind OpAExtend,
2824+
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
2825+
ElementCount VF, PartialReductionExtendKind OpAExtend,
28222826
PartialReductionExtendKind OpBExtend,
28232827
std::optional<unsigned> BinOp = std::nullopt) const override {
2824-
return Impl.getPartialReductionCost(Opcode, InputType, AccumType, VF,
2825-
OpAExtend, OpBExtend, BinOp);
2828+
return Impl.getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
2829+
AccumType, VF, OpAExtend, OpBExtend,
2830+
BinOp);
28262831
}
28272832

28282833
unsigned getMaxInterleaveFactor(ElementCount VF) override {

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -586,8 +586,8 @@ class TargetTransformInfoImplBase {
586586
bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
587587

588588
InstructionCost
589-
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
590-
ElementCount VF,
589+
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
590+
Type *AccumType, ElementCount VF,
591591
TTI::PartialReductionExtendKind OpAExtend,
592592
TTI::PartialReductionExtendKind OpBExtend,
593593
std::optional<unsigned> BinOp = std::nullopt) const {

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -864,11 +864,12 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
864864
}
865865

866866
InstructionCost TargetTransformInfo::getPartialReductionCost(
867-
unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
868-
PartialReductionExtendKind OpAExtend, PartialReductionExtendKind OpBExtend,
869-
std::optional<unsigned> BinOp) const {
870-
return TTIImpl->getPartialReductionCost(Opcode, InputType, AccumType, VF,
871-
OpAExtend, OpBExtend, BinOp);
867+
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
868+
ElementCount VF, PartialReductionExtendKind OpAExtend,
869+
PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
870+
return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
871+
AccumType, VF, OpAExtend, OpBExtend,
872+
BinOp);
872873
}
873874

874875
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
359359
}
360360

361361
InstructionCost
362-
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
363-
ElementCount VF,
362+
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
363+
Type *AccumType, ElementCount VF,
364364
TTI::PartialReductionExtendKind OpAExtend,
365365
TTI::PartialReductionExtendKind OpBExtend,
366366
std::optional<unsigned> BinOp) const {
@@ -371,7 +371,10 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
371371
if (Opcode != Instruction::Add)
372372
return Invalid;
373373

374-
EVT InputEVT = EVT::getEVT(InputType);
374+
if (InputTypeA != InputTypeB)
375+
return Invalid;
376+
377+
EVT InputEVT = EVT::getEVT(InputTypeA);
375378
EVT AccumEVT = EVT::getEVT(AccumType);
376379

377380
if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
@@ -411,7 +414,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
411414
!ST->isSVEorStreamingSVEAvailable()))
412415
return Invalid;
413416

414-
if (!BinOp || (*BinOp) != Instruction::Mul)
417+
if (!BinOp || *BinOp != Instruction::Mul)
415418
return Invalid;
416419

417420
return Cost;

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8824,10 +8824,6 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
88248824
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
88258825
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
88268826

8827-
// Check that the extends extend from the same type.
8828-
if (A->getType() != B->getType())
8829-
return std::nullopt;
8830-
88318827
TTI::PartialReductionExtendKind OpAExtend =
88328828
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
88338829
TTI::PartialReductionExtendKind OpBExtend =
@@ -8842,8 +8838,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
88428838
if (LoopVectorizationPlanner::getDecisionAndClampRange(
88438839
[&](ElementCount VF) {
88448840
InstructionCost Cost = TTI->getPartialReductionCost(
8845-
Update->getOpcode(), A->getType(), PHI->getType(), VF,
8846-
OpAExtend, OpBExtend, std::make_optional(BinOp->getOpcode()));
8841+
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
8842+
VF, OpAExtend, OpBExtend,
8843+
std::make_optional(BinOp->getOpcode()));
88478844
return Cost.isValid();
88488845
},
88498846
Range))

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,10 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
304304
VPRecipeBase *ExtBR = BinOpR->getOperand(1)->getDefiningRecipe();
305305

306306
auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
307-
auto *ExtTy = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
308-
: BinOpR->getOperand(0));
307+
auto *InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
308+
: BinOpR->getOperand(0));
309+
auto *InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
310+
: BinOpR->getOperand(1));
309311

310312
auto GetExtendKind = [](VPRecipeBase *R) {
311313
// The extend could come from outside the plan.
@@ -321,8 +323,8 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
321323
return TargetTransformInfo::PR_None;
322324
};
323325

324-
return Ctx.TTI.getPartialReductionCost(getOpcode(), ExtTy, PhiType, VF,
325-
GetExtendKind(ExtAR),
326+
return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
327+
PhiType, VF, GetExtendKind(ExtAR),
326328
GetExtendKind(ExtBR), Opcode);
327329
}
328330

0 commit comments

Comments
 (0)