@@ -197,6 +197,24 @@ struct VectorLayout {
197197 uint64_t SplitSize = 0 ;
198198};
199199
200+ static bool isStructOfMatchingFixedVectors (Type *Ty) {
201+ if (!isa<StructType>(Ty))
202+ return false ;
203+ unsigned StructSize = Ty->getNumContainedTypes ();
204+ if (StructSize < 1 )
205+ return false ;
206+ FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType (0 ));
207+ if (!VecTy)
208+ return false ;
209+ unsigned VecSize = VecTy->getNumElements ();
210+ for (unsigned I = 1 ; I < StructSize; I++) {
211+ VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType (I));
212+ if (!VecTy || VecSize != VecTy->getNumElements ())
213+ return false ;
214+ }
215+ return true ;
216+ }
217+
200218// / Concatenate the given fragments to a single vector value of the type
201219// / described in @p VS.
202220static Value *concatenate (IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
@@ -276,6 +294,7 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
276294 bool visitBitCastInst (BitCastInst &BCI);
277295 bool visitInsertElementInst (InsertElementInst &IEI);
278296 bool visitExtractElementInst (ExtractElementInst &EEI);
297+ bool visitExtractValueInst (ExtractValueInst &EVI);
279298 bool visitShuffleVectorInst (ShuffleVectorInst &SVI);
280299 bool visitPHINode (PHINode &PHI);
281300 bool visitLoadInst (LoadInst &LI);
@@ -667,14 +686,26 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
667686bool ScalarizerVisitor::isTriviallyScalarizable (Intrinsic::ID ID) {
668687 if (isTriviallyVectorizable (ID))
669688 return true ;
689+ // TODO: Move frexp to isTriviallyVectorizable.
690+ // https://github.com/llvm/llvm-project/issues/112408
691+ switch (ID) {
692+ case Intrinsic::frexp:
693+ return true ;
694+ }
670695 return Intrinsic::isTargetIntrinsic (ID) &&
671696 TTI->isTargetIntrinsicTriviallyScalarizable (ID);
672697}
673698
674699// / If a call to a vector typed intrinsic function, split into a scalar call per
675700// / element if possible for the intrinsic.
676701bool ScalarizerVisitor::splitCall (CallInst &CI) {
677- std::optional<VectorSplit> VS = getVectorSplit (CI.getType ());
702+ Type *CallType = CI.getType ();
703+ bool AreAllVectorsOfMatchingSize = isStructOfMatchingFixedVectors (CallType);
704+ std::optional<VectorSplit> VS;
705+ if (AreAllVectorsOfMatchingSize)
706+ VS = getVectorSplit (CallType->getContainedType (0 ));
707+ else
708+ VS = getVectorSplit (CallType);
678709 if (!VS)
679710 return false ;
680711
@@ -699,6 +730,23 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
699730 if (isVectorIntrinsicWithOverloadTypeAtArg (ID, -1 ))
700731 Tys.push_back (VS->SplitTy );
701732
733+ if (AreAllVectorsOfMatchingSize) {
734+ for (unsigned I = 1 ; I < CallType->getNumContainedTypes (); I++) {
735+ std::optional<VectorSplit> CurrVS =
736+ getVectorSplit (cast<FixedVectorType>(CallType->getContainedType (I)));
737+ // This case does not seem to happen, but it is possible for
738+ // VectorSplit.NumPacked >= NumElems. If that happens a VectorSplit
739+ // is not returned and we will bailout of handling this call.
740+ // The secondary bailout case is if NumPacked does not match.
741+ // This can happen if ScalarizeMinBits is not set to the default.
742+ // This means with certain ScalarizeMinBits intrinsics like frexp
743+ // will only scalarize when the struct elements have the same bitness.
744+ if (!CurrVS || CurrVS->NumPacked != VS->NumPacked )
745+ return false ;
746+ if (isVectorIntrinsicWithStructReturnOverloadAtField (ID, I))
747+ Tys.push_back (CurrVS->SplitTy );
748+ }
749+ }
702750 // Assumes that any vector type has the same number of elements as the return
703751 // vector type, which is true for all current intrinsics.
704752 for (unsigned I = 0 ; I != NumArgs; ++I) {
@@ -1030,6 +1078,31 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
10301078 return true ;
10311079}
10321080
1081+ bool ScalarizerVisitor::visitExtractValueInst (ExtractValueInst &EVI) {
1082+ Value *Op = EVI.getOperand (0 );
1083+ Type *OpTy = Op->getType ();
1084+ ValueVector Res;
1085+ if (!isStructOfMatchingFixedVectors (OpTy))
1086+ return false ;
1087+ Type *VecType = cast<FixedVectorType>(OpTy->getContainedType (0 ));
1088+ std::optional<VectorSplit> VS = getVectorSplit (VecType);
1089+ if (!VS)
1090+ return false ;
1091+ IRBuilder<> Builder (&EVI);
1092+ Scatterer Op0 = scatter (&EVI, Op, *VS);
1093+ assert (!EVI.getIndices ().empty () && " Make sure an index exists" );
1094+ // Note for our use case we only care about the top level index.
1095+ unsigned Index = EVI.getIndices ()[0 ];
1096+ for (unsigned OpIdx = 0 ; OpIdx < Op0.size (); ++OpIdx) {
1097+ Value *ResElem = Builder.CreateExtractValue (
1098+ Op0[OpIdx], Index, EVI.getName () + " .elem" + Twine (Index));
1099+ Res.push_back (ResElem);
1100+ }
1101+
1102+ gather (&EVI, Res, *VS);
1103+ return true ;
1104+ }
1105+
10331106bool ScalarizerVisitor::visitExtractElementInst (ExtractElementInst &EEI) {
10341107 std::optional<VectorSplit> VS = getVectorSplit (EEI.getOperand (0 )->getType ());
10351108 if (!VS)
@@ -1209,6 +1282,35 @@ bool ScalarizerVisitor::finish() {
12091282 Res = concatenate (Builder, CV, VS, Op->getName ());
12101283
12111284 Res->takeName (Op);
1285+ } else if (auto *Ty = dyn_cast<StructType>(Op->getType ())) {
1286+ BasicBlock *BB = Op->getParent ();
1287+ IRBuilder<> Builder (Op);
1288+ if (isa<PHINode>(Op))
1289+ Builder.SetInsertPoint (BB, BB->getFirstInsertionPt ());
1290+
1291+ // Iterate over each element in the struct
1292+ unsigned NumOfStructElements = Ty->getNumElements ();
1293+ SmallVector<ValueVector, 4 > ElemCV (NumOfStructElements);
1294+ for (unsigned I = 0 ; I < NumOfStructElements; ++I) {
1295+ for (auto *CVelem : CV) {
1296+ Value *Elem = Builder.CreateExtractValue (
1297+ CVelem, I, Op->getName () + " .elem" + Twine (I));
1298+ ElemCV[I].push_back (Elem);
1299+ }
1300+ }
1301+ Res = PoisonValue::get (Ty);
1302+ for (unsigned I = 0 ; I < NumOfStructElements; ++I) {
1303+ Type *ElemTy = Ty->getElementType (I);
1304+ assert (isa<FixedVectorType>(ElemTy) &&
1305+ " Only Structs of all FixedVectorType supported" );
1306+ VectorSplit VS = *getVectorSplit (ElemTy);
1307+ assert (VS.NumFragments == CV.size ());
1308+
1309+ Value *ConcatenatedVector =
1310+ concatenate (Builder, ElemCV[I], VS, Op->getName ());
1311+ Res = Builder.CreateInsertValue (Res, ConcatenatedVector, I,
1312+ Op->getName () + " .insert" );
1313+ }
12121314 } else {
12131315 assert (CV.size () == 1 && Op->getType () == CV[0 ]->getType ());
12141316 Res = CV[0 ];
0 commit comments