@@ -197,6 +197,11 @@ struct VectorLayout {
197197 uint64_t SplitSize = 0 ;
198198};
199199
200+ static bool isStructOfVectors (Type *Ty) {
201+ return isa<StructType>(Ty) && Ty->getNumContainedTypes () > 0 &&
202+ isa<FixedVectorType>(Ty->getContainedType (0 ));
203+ }
204+
200205// / Concatenate the given fragments to a single vector value of the type
201206// / described in @p VS.
202207static Value *concatenate (IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
@@ -276,6 +281,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
276281 bool visitBitCastInst (BitCastInst &BCI);
277282 bool visitInsertElementInst (InsertElementInst &IEI);
278283 bool visitExtractElementInst (ExtractElementInst &EEI);
284+ bool visitExtractValueInst (ExtractValueInst &EVI);
279285 bool visitShuffleVectorInst (ShuffleVectorInst &SVI);
280286 bool visitPHINode (PHINode &PHI);
281287 bool visitLoadInst (LoadInst &LI);
@@ -552,7 +558,10 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
552558// Determine how Ty is split, if at all.
553559std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit (Type *Ty) {
554560 VectorSplit Split;
555- Split.VecTy = dyn_cast<FixedVectorType>(Ty);
561+ if (isStructOfVectors (Ty))
562+ Split.VecTy = cast<FixedVectorType>(Ty->getContainedType (0 ));
563+ else
564+ Split.VecTy = dyn_cast<FixedVectorType>(Ty);
556565 if (!Split.VecTy )
557566 return {};
558567
@@ -1030,6 +1039,33 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
10301039 return true ;
10311040}
10321041
1042+ bool ScalarizerVisitor::visitExtractValueInst (ExtractValueInst &EVI) {
1043+ Value *Op = EVI.getOperand (0 );
1044+ Type *OpTy = Op->getType ();
1045+ ValueVector Res;
1046+ if (!isStructOfVectors (OpTy))
1047+ return false ;
1048+ // Note: isStructOfVectors is also used in getVectorSplit.
1049+ // The intent is to bail on this visit if it isn't a struct
1050+ // of vectors. Downside is that when it is true we do two
1051+ // isStructOfVectors calls.
1052+ std::optional<VectorSplit> VS = getVectorSplit (OpTy);
1053+ if (!VS)
1054+ return false ;
1055+ Scatterer Op0 = scatter (&EVI, Op, *VS);
1056+ assert (!EVI.getIndices ().empty () && " Make sure an index exists" );
1057+ // Note for our use case we only care about the top level index.
1058+ unsigned Index = EVI.getIndices ()[0 ];
1059+ for (unsigned OpIdx = 0 ; OpIdx < Op0.size (); ++OpIdx) {
1060+ Value *ResElem = Builder.CreateExtractValue (
1061+ Op0[OpIdx], Index, EVI.getName () + " .elem" + std::to_string (Index));
1062+ Res.push_back (ResElem);
1063+ }
1064+ // replaceUses(&EVI, Res);
1065+ gather (&EVI, Res, *VS);
1066+ return true ;
1067+ }
1068+
10331069bool ScalarizerVisitor::visitExtractElementInst (ExtractElementInst &EEI) {
10341070 std::optional<VectorSplit> VS = getVectorSplit (EEI.getOperand (0 )->getType ());
10351071 if (!VS)
@@ -1196,7 +1232,7 @@ bool ScalarizerVisitor::finish() {
11961232 if (!Op->use_empty ()) {
11971233 // The value is still needed, so recreate it using a series of
11981234 // insertelements and/or shufflevectors.
1199- Value *Res;
1235+ Value *Res = nullptr ;
12001236 if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType ())) {
12011237 BasicBlock *BB = Op->getParent ();
12021238 IRBuilder<> Builder (Op);
@@ -1209,6 +1245,35 @@ bool ScalarizerVisitor::finish() {
12091245 Res = concatenate (Builder, CV, VS, Op->getName ());
12101246
12111247 Res->takeName (Op);
1248+ } else if (auto *Ty = dyn_cast<StructType>(Op->getType ())) {
1249+ BasicBlock *BB = Op->getParent ();
1250+ IRBuilder<> Builder (Op);
1251+ if (isa<PHINode>(Op))
1252+ Builder.SetInsertPoint (BB, BB->getFirstInsertionPt ());
1253+
1254+ // Iterate over each element in the struct
1255+ uint NumOfStructElements = Ty->getNumElements ();
1256+ SmallVector<ValueVector, 4 > ElemCV (NumOfStructElements);
1257+ for (unsigned I = 0 ; I < NumOfStructElements; ++I) {
1258+ for (auto *CVelem : CV) {
1259+ Value *Elem = Builder.CreateExtractValue (
1260+ CVelem, I, Op->getName () + " .elem" + std::to_string (I));
1261+ ElemCV[I].push_back (Elem);
1262+ }
1263+ }
1264+ Res = PoisonValue::get (Ty);
1265+ for (unsigned I = 0 ; I < NumOfStructElements; ++I) {
1266+ Type *ElemTy = Ty->getElementType (I);
1267+ assert (isa<FixedVectorType>(ElemTy) &&
1268+ " Only Structs of all FixedVectorType supported" );
1269+ VectorSplit VS = *getVectorSplit (ElemTy);
1270+ assert (VS.NumFragments == CV.size ());
1271+
1272+ Value *ConcatenatedVector =
1273+ concatenate (Builder, ElemCV[I], VS, Op->getName ());
1274+ Res = Builder.CreateInsertValue (Res, ConcatenatedVector, I,
1275+ Op->getName () + " .insert" );
1276+ }
12121277 } else {
12131278 assert (CV.size () == 1 && Op->getType () == CV[0 ]->getType ());
12141279 Res = CV[0 ];
0 commit comments