Skip to content

Commit 60d1332

Browse files
committed
Init implement TU select
1 parent 2b4c13f commit 60d1332

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,6 +1266,9 @@ class VPInstruction : public VPRecipeWithIRFlags {
12661266
// operand). Only generates scalar values (either for the first lane only or
12671267
// for all lanes, depending on its uses).
12681268
PtrAdd,
1269+
// Selects elements from two vectors (second and third operand) based on a
1270+
// condition vector (first operand) and a pivot index (fourth operand).
1271+
MergeUntilPivot,
12691272
};
12701273

12711274
private:

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
145145
case VPInstruction::FirstOrderRecurrenceSplice:
146146
case VPInstruction::LogicalAnd:
147147
case VPInstruction::PtrAdd:
148+
case VPInstruction::MergeUntilPivot:
148149
return false;
149150
default:
150151
return true;
@@ -668,7 +669,18 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
668669
}
669670
return NewPhi;
670671
}
672+
case VPInstruction::MergeUntilPivot: {
673+
assert(Part == 0 && "No unrolling expected for predicated vectorization.");
674+
Value *Cond = State.get(getOperand(0), Part);
675+
Value *OnTrue = State.get(getOperand(1), Part);
676+
Value *OnFalse = State.get(getOperand(2), Part);
677+
Value *Pivot = State.get(getOperand(3), VPIteration(0, 0));
678+
assert(Pivot->getType()->isIntegerTy() && "Pivot should be an integer.");
671679

680+
return Builder.CreateIntrinsic(Intrinsic::vp_merge, {OnTrue->getType()},
681+
{Cond, OnTrue, OnFalse, Pivot}, nullptr,
682+
Name);
683+
}
672684
default:
673685
llvm_unreachable("Unsupported opcode for instruction");
674686
}
@@ -759,6 +771,9 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
759771
case VPInstruction::BranchOnCond:
760772
case VPInstruction::ResumePhi:
761773
return true;
774+
case VPInstruction::MergeUntilPivot:
775+
// Pivot must be an integer.
776+
return Op == getOperand(3);
762777
};
763778
llvm_unreachable("switch should return");
764779
}
@@ -777,6 +792,7 @@ bool VPInstruction::onlyFirstPartUsed(const VPValue *Op) const {
777792
case VPInstruction::BranchOnCount:
778793
case VPInstruction::BranchOnCond:
779794
case VPInstruction::CanonicalIVIncrementForPart:
795+
case VPInstruction::MergeUntilPivot:
780796
return true;
781797
};
782798
llvm_unreachable("switch should return");
@@ -843,6 +859,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
843859
case VPInstruction::PtrAdd:
844860
O << "ptradd";
845861
break;
862+
case VPInstruction::MergeUntilPivot:
863+
O << "merge-until-pivot";
864+
break;
846865
default:
847866
O << Instruction::getOpcodeName(getOpcode());
848867
}

0 commit comments

Comments
 (0)