-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[VPlan] Improve code in VPInstruction::generate (NFC) #169470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-llvm-transforms
Author: Ramkumar Ramachandra (artagnon)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/169470.diff
1 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 54fdec3bcf4a1..e73967e1e96dc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -722,10 +722,9 @@ Value *VPInstruction::generate(VPTransformState &State) {
return Builder.CreateCmp(CmpInst::Predicate::ICMP_ULT, VIVElem0, ScalarTC,
Name);
- auto *Int1Ty = Type::getInt1Ty(Builder.getContext());
- auto PredTy = VectorType::get(
- Int1Ty, State.VF * cast<ConstantInt>(getOperand(2)->getLiveInIRValue())
- ->getZExtValue());
+ ElementCount EC = State.VF.multiplyCoefficientBy(
+ cast<ConstantInt>(getOperand(2)->getLiveInIRValue())->getZExtValue());
+ auto *PredTy = VectorType::get(Builder.getInt1Ty(), EC);
return Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
{PredTy, ScalarTC->getType()},
{VIVElem0, ScalarTC}, nullptr, Name);
@@ -755,7 +754,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
Value *Step = createStepForVF(Builder, ScalarTC->getType(), State.VF, UF);
Value *Sub = Builder.CreateSub(ScalarTC, Step);
Value *Cmp = Builder.CreateICmp(CmpInst::Predicate::ICMP_UGT, ScalarTC, Step);
- Value *Zero = ConstantInt::get(ScalarTC->getType(), 0);
+ Value *Zero = ConstantInt::getNullValue(ScalarTC->getType());
return Builder.CreateSelect(Cmp, Sub, Zero);
}
case VPInstruction::ExplicitVectorLength: {
@@ -767,11 +766,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
"Requested vector length should be an integer.");
assert(State.VF.isScalable() && "Expected scalable vector factor.");
- Value *VFArg = State.Builder.getInt32(State.VF.getKnownMinValue());
+ Value *VFArg = Builder.getInt32(State.VF.getKnownMinValue());
- Value *EVL = State.Builder.CreateIntrinsic(
- State.Builder.getInt32Ty(), Intrinsic::experimental_get_vector_length,
- {AVL, VFArg, State.Builder.getTrue()});
+ Value *EVL = Builder.CreateIntrinsic(
+ Builder.getInt32Ty(), Intrinsic::experimental_get_vector_length,
+ {AVL, VFArg, Builder.getTrue()});
return EVL;
}
case VPInstruction::CanonicalIVIncrementForPart: {
@@ -808,8 +807,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
cast<StructType>(State.TypeAnalysis.inferScalarType(getOperand(0)));
Value *Res = PoisonValue::get(toVectorizedTy(StructTy, State.VF));
for (const auto &[LaneIndex, Op] : enumerate(operands())) {
- for (unsigned FieldIndex = 0; FieldIndex != StructTy->getNumElements();
- FieldIndex++) {
+ for (unsigned FieldIndex : seq<unsigned>(StructTy->getNumElements())) {
Value *ScalarValue =
Builder.CreateExtractValue(State.get(Op, true), FieldIndex);
Value *VectorValue = Builder.CreateExtractValue(Res, FieldIndex);
@@ -825,8 +823,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
auto NumOfElements = ElementCount::getFixed(getNumOperands());
Value *Res = PoisonValue::get(toVectorizedTy(ScalarTy, NumOfElements));
for (const auto &[Idx, Op] : enumerate(operands()))
- Res = State.Builder.CreateInsertElement(Res, State.get(Op, true),
- State.Builder.getInt32(Idx));
+ Res = Builder.CreateInsertElement(Res, State.get(Op, true),
+ Builder.getInt32(Idx));
return Res;
}
case VPInstruction::ReductionStartVector: {
@@ -839,9 +837,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
ElementCount VF = State.VF.divideCoefficientBy(
cast<ConstantInt>(getOperand(2)->getLiveInIRValue())->getZExtValue());
auto *Iden = Builder.CreateVectorSplat(VF, State.get(getOperand(1), true));
- Constant *Zero = Builder.getInt32(0);
return Builder.CreateInsertElement(Iden, State.get(getOperand(0), true),
- Zero);
+ Builder.getInt32(0));
}
case VPInstruction::ComputeAnyOfResult: {
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
@@ -849,7 +846,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
Value *ReducedPartRdx = State.get(getOperand(2));
- for (unsigned Idx = 3; Idx < getNumOperands(); ++Idx)
+ for (unsigned Idx : seq<unsigned>(3, getNumOperands()))
ReducedPartRdx =
Builder.CreateBinOp(Instruction::Or, State.get(getOperand(Idx)),
ReducedPartRdx, "bin.rdx");
@@ -877,7 +874,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
MinMaxKind = IsSigned ? RecurKind::SMax : RecurKind::UMax;
else
MinMaxKind = IsSigned ? RecurKind::SMin : RecurKind::UMin;
- for (unsigned Part = 1; Part < UF; ++Part)
+ for (unsigned Part : seq<unsigned>(1, UF))
ReducedPartRdx = createMinMaxOp(Builder, MinMaxKind, ReducedPartRdx,
State.get(getOperand(3 + Part)));
@@ -900,7 +897,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
// each part of the reduction.
unsigned UF = getNumOperands() - 1;
VectorParts RdxParts(UF);
- for (unsigned Part = 0; Part < UF; ++Part)
+ for (unsigned Part : seq<unsigned>(UF))
RdxParts[Part] = State.get(getOperand(1 + Part), PhiR->isInLoop());
IRBuilderBase::FastMathFlagGuard FMFG(Builder);
@@ -918,14 +915,12 @@ Value *VPInstruction::generate(VPTransformState &State) {
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK))
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
else {
- Instruction::BinaryOps Opcode;
// For sub-recurrences, each UF's reduction variable is already
// negative, we need to do: reduce.add(-acc_uf0 + -acc_uf1)
- if (RK == RecurKind::Sub)
- Opcode = Instruction::Add;
- else
- Opcode =
- (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(RK);
+ Instruction::BinaryOps Opcode =
+ RK == RecurKind::Sub
+ ? Instruction::Add
+ : (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(RK);
ReducedPartRdx =
Builder.CreateBinOp(Opcode, RdxPart, ReducedPartRdx, "bin.rdx");
}
@@ -990,7 +985,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
Value *LaneToExtract = State.get(getOperand(0), true);
Type *IdxTy = State.TypeAnalysis.inferScalarType(getOperand(0));
Value *Res = nullptr;
- Value *RuntimeVF = getRuntimeVF(State.Builder, IdxTy, State.VF);
+ Value *RuntimeVF = getRuntimeVF(Builder, IdxTy, State.VF);
for (unsigned Idx = 1; Idx != getNumOperands(); ++Idx) {
Value *VectorStart =
@@ -1020,8 +1015,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
// If there are multiple operands, create a chain of selects to pick the
// first operand with an active lane and add the number of lanes of the
// preceding operands.
- Value *RuntimeVF =
- getRuntimeVF(State.Builder, State.Builder.getInt64Ty(), State.VF);
+ Value *RuntimeVF = getRuntimeVF(Builder, Builder.getInt64Ty(), State.VF);
unsigned LastOpIdx = getNumOperands() - 1;
Value *Res = nullptr;
for (int Idx = LastOpIdx; Idx >= 0; --Idx) {
|
|
@llvm/pr-subscribers-vectorizers
Author: Ramkumar Ramachandra (artagnon)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/169470.diff
1 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 54fdec3bcf4a1..e73967e1e96dc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -722,10 +722,9 @@ Value *VPInstruction::generate(VPTransformState &State) {
return Builder.CreateCmp(CmpInst::Predicate::ICMP_ULT, VIVElem0, ScalarTC,
Name);
- auto *Int1Ty = Type::getInt1Ty(Builder.getContext());
- auto PredTy = VectorType::get(
- Int1Ty, State.VF * cast<ConstantInt>(getOperand(2)->getLiveInIRValue())
- ->getZExtValue());
+ ElementCount EC = State.VF.multiplyCoefficientBy(
+ cast<ConstantInt>(getOperand(2)->getLiveInIRValue())->getZExtValue());
+ auto *PredTy = VectorType::get(Builder.getInt1Ty(), EC);
return Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
{PredTy, ScalarTC->getType()},
{VIVElem0, ScalarTC}, nullptr, Name);
@@ -755,7 +754,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
Value *Step = createStepForVF(Builder, ScalarTC->getType(), State.VF, UF);
Value *Sub = Builder.CreateSub(ScalarTC, Step);
Value *Cmp = Builder.CreateICmp(CmpInst::Predicate::ICMP_UGT, ScalarTC, Step);
- Value *Zero = ConstantInt::get(ScalarTC->getType(), 0);
+ Value *Zero = ConstantInt::getNullValue(ScalarTC->getType());
return Builder.CreateSelect(Cmp, Sub, Zero);
}
case VPInstruction::ExplicitVectorLength: {
@@ -767,11 +766,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
"Requested vector length should be an integer.");
assert(State.VF.isScalable() && "Expected scalable vector factor.");
- Value *VFArg = State.Builder.getInt32(State.VF.getKnownMinValue());
+ Value *VFArg = Builder.getInt32(State.VF.getKnownMinValue());
- Value *EVL = State.Builder.CreateIntrinsic(
- State.Builder.getInt32Ty(), Intrinsic::experimental_get_vector_length,
- {AVL, VFArg, State.Builder.getTrue()});
+ Value *EVL = Builder.CreateIntrinsic(
+ Builder.getInt32Ty(), Intrinsic::experimental_get_vector_length,
+ {AVL, VFArg, Builder.getTrue()});
return EVL;
}
case VPInstruction::CanonicalIVIncrementForPart: {
@@ -808,8 +807,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
cast<StructType>(State.TypeAnalysis.inferScalarType(getOperand(0)));
Value *Res = PoisonValue::get(toVectorizedTy(StructTy, State.VF));
for (const auto &[LaneIndex, Op] : enumerate(operands())) {
- for (unsigned FieldIndex = 0; FieldIndex != StructTy->getNumElements();
- FieldIndex++) {
+ for (unsigned FieldIndex : seq<unsigned>(StructTy->getNumElements())) {
Value *ScalarValue =
Builder.CreateExtractValue(State.get(Op, true), FieldIndex);
Value *VectorValue = Builder.CreateExtractValue(Res, FieldIndex);
@@ -825,8 +823,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
auto NumOfElements = ElementCount::getFixed(getNumOperands());
Value *Res = PoisonValue::get(toVectorizedTy(ScalarTy, NumOfElements));
for (const auto &[Idx, Op] : enumerate(operands()))
- Res = State.Builder.CreateInsertElement(Res, State.get(Op, true),
- State.Builder.getInt32(Idx));
+ Res = Builder.CreateInsertElement(Res, State.get(Op, true),
+ Builder.getInt32(Idx));
return Res;
}
case VPInstruction::ReductionStartVector: {
@@ -839,9 +837,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
ElementCount VF = State.VF.divideCoefficientBy(
cast<ConstantInt>(getOperand(2)->getLiveInIRValue())->getZExtValue());
auto *Iden = Builder.CreateVectorSplat(VF, State.get(getOperand(1), true));
- Constant *Zero = Builder.getInt32(0);
return Builder.CreateInsertElement(Iden, State.get(getOperand(0), true),
- Zero);
+ Builder.getInt32(0));
}
case VPInstruction::ComputeAnyOfResult: {
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
@@ -849,7 +846,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
Value *ReducedPartRdx = State.get(getOperand(2));
- for (unsigned Idx = 3; Idx < getNumOperands(); ++Idx)
+ for (unsigned Idx : seq<unsigned>(3, getNumOperands()))
ReducedPartRdx =
Builder.CreateBinOp(Instruction::Or, State.get(getOperand(Idx)),
ReducedPartRdx, "bin.rdx");
@@ -877,7 +874,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
MinMaxKind = IsSigned ? RecurKind::SMax : RecurKind::UMax;
else
MinMaxKind = IsSigned ? RecurKind::SMin : RecurKind::UMin;
- for (unsigned Part = 1; Part < UF; ++Part)
+ for (unsigned Part : seq<unsigned>(1, UF))
ReducedPartRdx = createMinMaxOp(Builder, MinMaxKind, ReducedPartRdx,
State.get(getOperand(3 + Part)));
@@ -900,7 +897,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
// each part of the reduction.
unsigned UF = getNumOperands() - 1;
VectorParts RdxParts(UF);
- for (unsigned Part = 0; Part < UF; ++Part)
+ for (unsigned Part : seq<unsigned>(UF))
RdxParts[Part] = State.get(getOperand(1 + Part), PhiR->isInLoop());
IRBuilderBase::FastMathFlagGuard FMFG(Builder);
@@ -918,14 +915,12 @@ Value *VPInstruction::generate(VPTransformState &State) {
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK))
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
else {
- Instruction::BinaryOps Opcode;
// For sub-recurrences, each UF's reduction variable is already
// negative, we need to do: reduce.add(-acc_uf0 + -acc_uf1)
- if (RK == RecurKind::Sub)
- Opcode = Instruction::Add;
- else
- Opcode =
- (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(RK);
+ Instruction::BinaryOps Opcode =
+ RK == RecurKind::Sub
+ ? Instruction::Add
+ : (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(RK);
ReducedPartRdx =
Builder.CreateBinOp(Opcode, RdxPart, ReducedPartRdx, "bin.rdx");
}
@@ -990,7 +985,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
Value *LaneToExtract = State.get(getOperand(0), true);
Type *IdxTy = State.TypeAnalysis.inferScalarType(getOperand(0));
Value *Res = nullptr;
- Value *RuntimeVF = getRuntimeVF(State.Builder, IdxTy, State.VF);
+ Value *RuntimeVF = getRuntimeVF(Builder, IdxTy, State.VF);
for (unsigned Idx = 1; Idx != getNumOperands(); ++Idx) {
Value *VectorStart =
@@ -1020,8 +1015,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
// If there are multiple operands, create a chain of selects to pick the
// first operand with an active lane and add the number of lanes of the
// preceding operands.
- Value *RuntimeVF =
- getRuntimeVF(State.Builder, State.Builder.getInt64Ty(), State.VF);
+ Value *RuntimeVF = getRuntimeVF(Builder, Builder.getInt64Ty(), State.VF);
unsigned LastOpIdx = getNumOperands() - 1;
Value *Res = nullptr;
for (int Idx = LastOpIdx; Idx >= 0; --Idx) {
|
|
Please can you add more description to the commit message explaining the purpose of the PR and/or why it's improving the code? |
Done. |
SamTebbs33
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks sensible to me 👍
Make miscellaneous improvements including inlining some expressions and re-using the existing State.Builder reference.