@@ -551,6 +551,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
551551 }
552552 case Instruction::ExtractElement: {
553553 assert (State.VF .isVector () && " Only extract elements from vectors" );
554+ if (getOperand (1 )->isLiveIn ()) {
555+ unsigned IdxToExtract =
556+ cast<ConstantInt>(getOperand (1 )->getLiveInIRValue ())->getZExtValue ();
557+ return State.get (getOperand (0 ), VPLane (IdxToExtract));
558+ }
554559 Value *Vec = State.get (getOperand (0 ));
555560 Value *Idx = State.get (getOperand (1 ), /* IsScalar=*/ true );
556561 return Builder.CreateExtractElement (Vec, Idx, Name);
@@ -664,6 +669,34 @@ Value *VPInstruction::generate(VPTransformState &State) {
664669 return Builder.CreateVectorSplat (
665670 State.VF , State.get (getOperand (0 ), /* IsScalar*/ true ), " broadcast" );
666671 }
672+ case VPInstruction::BuildStructVector: {
673+ // For struct types, we need to build a new 'wide' struct type, where each
674+ // element is widened, i.e., we create a struct of vectors.
675+ auto *StructTy =
676+ cast<StructType>(State.TypeAnalysis .inferScalarType (getOperand (0 )));
677+ Value *Res = PoisonValue::get (toVectorizedTy (StructTy, State.VF ));
678+ for (const auto &[LaneIndex, Op] : enumerate(operands ())) {
679+ for (unsigned FieldIndex = 0 ; FieldIndex != StructTy->getNumElements ();
680+ FieldIndex++) {
681+ Value *ScalarValue =
682+ Builder.CreateExtractValue (State.get (Op, true ), FieldIndex);
683+ Value *VectorValue = Builder.CreateExtractValue (Res, FieldIndex);
684+ VectorValue =
685+ Builder.CreateInsertElement (VectorValue, ScalarValue, LaneIndex);
686+ Res = Builder.CreateInsertValue (Res, VectorValue, FieldIndex);
687+ }
688+ }
689+ return Res;
690+ }
691+ case VPInstruction::BuildVector: {
692+ auto *ScalarTy = State.TypeAnalysis .inferScalarType (getOperand (0 ));
693+ auto NumOfElements = ElementCount::getFixed (getNumOperands ());
694+ Value *Res = PoisonValue::get (toVectorizedTy (ScalarTy, NumOfElements));
695+ for (const auto &[Idx, Op] : enumerate(operands ()))
696+ Res = State.Builder .CreateInsertElement (Res, State.get (Op, true ),
697+ State.Builder .getInt32 (Idx));
698+ return Res;
699+ }
667700 case VPInstruction::ReductionStartVector: {
668701 if (State.VF .isScalar ())
669702 return State.get (getOperand (0 ), true );
@@ -953,10 +986,11 @@ void VPInstruction::execute(VPTransformState &State) {
953986 if (!hasResult ())
954987 return ;
955988 assert (GeneratedValue && " generate must produce a value" );
956- assert (
957- (GeneratedValue->getType ()->isVectorTy () == !GeneratesPerFirstLaneOnly ||
958- State.VF .isScalar ()) &&
959- " scalar value but not only first lane defined" );
989+ assert ((((GeneratedValue->getType ()->isVectorTy () ||
990+ GeneratedValue->getType ()->isStructTy ()) ==
991+ !GeneratesPerFirstLaneOnly) ||
992+ State.VF .isScalar ()) &&
993+ " scalar value but not only first lane defined" );
960994 State.set (this , GeneratedValue,
961995 /* IsScalar*/ GeneratesPerFirstLaneOnly);
962996}
@@ -970,6 +1004,8 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
9701004 case Instruction::ICmp:
9711005 case Instruction::Select:
9721006 case VPInstruction::AnyOf:
1007+ case VPInstruction::BuildStructVector:
1008+ case VPInstruction::BuildVector:
9731009 case VPInstruction::CalculateTripCountMinusVF:
9741010 case VPInstruction::CanonicalIVIncrementForPart:
9751011 case VPInstruction::ExtractLastElement:
@@ -1092,6 +1128,12 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
10921128 case VPInstruction::Broadcast:
10931129 O << " broadcast" ;
10941130 break ;
1131+ case VPInstruction::BuildStructVector:
1132+ O << " buildstructvector" ;
1133+ break ;
1134+ case VPInstruction::BuildVector:
1135+ O << " buildvector" ;
1136+ break ;
10951137 case VPInstruction::ExtractLastElement:
10961138 O << " extract-last-element" ;
10971139 break ;
@@ -2686,45 +2728,27 @@ static void scalarizeInstruction(const Instruction *Instr,
26862728
// Generates scalar IR for this recipe's underlying instruction by calling
// scalarizeInstruction. This is a diff hunk: lines marked '-' are the previous
// implementation, lines marked '+' are its replacement.
26872729void VPReplicateRecipe::execute (VPTransformState &State) {
26882730 Instruction *UI = getUnderlyingInstr ();
2689- if (State.Lane ) { // Generate a single instance.
2690- assert ((State.VF .isScalar () || !isSingleScalar ()) &&
2691- " uniform recipe shouldn't be predicated" );
2692- assert (!State.VF .isScalable () && " Can't scalarize a scalable vector" );
2693- scalarizeInstruction (UI, this , *State.Lane , State);
2694- // Insert scalar instance packing it into a vector.
2695- if (State.VF .isVector () && shouldPack ()) {
2696- Value *WideValue;
2697- // If we're constructing lane 0, initialize to start from poison.
2698- if (State.Lane ->isFirstLane ()) {
2699- assert (!State.VF .isScalable () && " VF is assumed to be non scalable." );
2700- WideValue = PoisonValue::get (VectorType::get (UI->getType (), State.VF ));
2701- } else {
2702- WideValue = State.get (this );
2703- }
2704- State.set (this , State.packScalarIntoVectorizedValue (this , WideValue,
2705- *State.Lane ));
2706- }
2707- return ;
2708- }
27092731
2710- if (IsSingleScalar) {
2711- // Uniform within VL means we need to generate lane 0.
// New flow: when no lane is set in State, the recipe is asserted to be a
// single scalar and only the lane-0 instance is generated.
2732+ if (!State.Lane ) {
2733+ assert (IsSingleScalar && " VPReplicateRecipes outside replicate regions "
2734+ " must have already been unrolled" );
27122735 scalarizeInstruction (UI, this , VPLane (0 ), State);
27132736 return ;
27142737 }
27152738
2716- // A store of a loop varying value to a uniform address only needs the last
2717- // copy of the store.
2718- if (isa<StoreInst>(UI) && vputils::isSingleScalar (getOperand (1 ))) {
2719- auto Lane = VPLane::getLastLaneForVF (State.VF );
2720- scalarizeInstruction (UI, this , VPLane (Lane), State);
2721- return ;
// New flow: a lane is set, so generate exactly one scalar instance for that
// lane. The old "last lane only" store case and the loop over all VF lanes
// (removed lines around this point) no longer appear in this function.
2739+ assert ((State.VF .isScalar () || !isSingleScalar ()) &&
2740+ " uniform recipe shouldn't be predicated" );
2741+ assert (!State.VF .isScalable () && " Can't scalarize a scalable vector" );
2742+ scalarizeInstruction (UI, this , *State.Lane , State);
2743+ // Insert scalar instance packing it into a vector.
// For the first lane the wide value starts as poison; for later lanes it
// continues from the value previously stored for this recipe via State.set.
2744+ if (State.VF .isVector () && shouldPack ()) {
2745+ Value *WideValue =
2746+ State.Lane ->isFirstLane ()
2747+ ? PoisonValue::get (VectorType::get (UI->getType (), State.VF ))
2748+ : State.get (this );
2749+ State.set (this , State.packScalarIntoVectorizedValue (this , WideValue,
2750+ *State.Lane ));
27222751 }
2723-
2724- // Generate scalar instances for all VF lanes.
2725- const unsigned EndLane = State.VF .getFixedValue ();
2726- for (unsigned Lane = 0 ; Lane < EndLane; ++Lane)
2727- scalarizeInstruction (UI, this , VPLane (Lane), State);
27282752}
27292753
27302754bool VPReplicateRecipe::shouldPack () const {
0 commit comments