@@ -200,11 +200,17 @@ struct VectorLayout {
200200static bool isStructAllVectors (Type *Ty) {
201201 if (!isa<StructType>(Ty))
202202 return false ;
203-
204- for (unsigned I = 0 ; I < Ty->getNumContainedTypes (); I++)
205- if (!isa<FixedVectorType>(Ty->getContainedType (I)))
203+ if (Ty->getNumContainedTypes () < 1 )
204+ return false ;
205+ FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType (0 ));
206+ if (!VecTy)
207+ return false ;
208+ unsigned VecSize = VecTy->getNumElements ();
209+ for (unsigned I = 1 ; I < Ty->getNumContainedTypes (); I++) {
210+ VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType (I));
211+ if (!VecTy || VecSize != VecTy->getNumElements ())
206212 return false ;
207-
213+ }
208214 return true ;
209215}
210216
@@ -679,8 +685,9 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
679685bool ScalarizerVisitor::isTriviallyScalarizable (Intrinsic::ID ID) {
680686 if (isTriviallyVectorizable (ID))
681687 return true ;
688+ // TODO: investigate vectorizable frexp
682689 switch (ID) {
683- case Intrinsic::frexp:
690+ case Intrinsic::frexp:
684691 return true ;
685692 }
686693 return Intrinsic::isTargetIntrinsic (ID) &&
@@ -690,10 +697,10 @@ bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
690697// / If a call to a vector typed intrinsic function, split into a scalar call per
691698// / element if possible for the intrinsic.
692699bool ScalarizerVisitor::splitCall (CallInst &CI) {
693- Type* CallType = CI.getType ();
694- bool areAllVectors = isStructAllVectors (CallType);
695- std::optional<VectorSplit> VS;
696- if (areAllVectors )
700+ Type * CallType = CI.getType ();
701+ bool AreAllVectors = isStructAllVectors (CallType);
702+ std::optional<VectorSplit> VS;
703+ if (AreAllVectors )
697704 VS = getVectorSplit (CallType->getContainedType (0 ));
698705 else
699706 VS = getVectorSplit (CallType);
@@ -721,12 +728,12 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
721728 if (isVectorIntrinsicWithOverloadTypeAtArg (ID, -1 ))
722729 Tys.push_back (VS->SplitTy );
723730
724- if (areAllVectors ) {
725- Type* PrevType = CallType->getContainedType (0 );
726- Type* CallType = CI.getType ();
727- for (unsigned I = 1 ; I < CallType->getNumContainedTypes (); I++) {
728- Type* CurrType = cast<FixedVectorType>(CallType->getContainedType (I));
729- if (PrevType != CurrType) {
731+ if (AreAllVectors ) {
732+ Type * PrevType = CallType->getContainedType (0 );
733+ Type * CallType = CI.getType ();
734+ for (unsigned I = 1 ; I < CallType->getNumContainedTypes (); I++) {
735+ Type * CurrType = cast<FixedVectorType>(CallType->getContainedType (I));
736+ if (PrevType != CurrType) {
730737 std::optional<VectorSplit> CurrVS = getVectorSplit (CurrType);
731738 Tys.push_back (CurrVS->SplitTy );
732739 PrevType = CurrType;
@@ -1070,7 +1077,7 @@ bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
10701077 ValueVector Res;
10711078 if (!isStructAllVectors (OpTy))
10721079 return false ;
1073- Type* VecType = cast<FixedVectorType>(OpTy->getContainedType (0 ));
1080+ Type * VecType = cast<FixedVectorType>(OpTy->getContainedType (0 ));
10741081 std::optional<VectorSplit> VS = getVectorSplit (VecType);
10751082 if (!VS)
10761083 return false ;
@@ -1084,7 +1091,7 @@ bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
10841091 Op0[OpIdx], Index, EVI.getName () + " .elem" + std::to_string (Index));
10851092 Res.push_back (ResElem);
10861093 }
1087- // replaceUses(&EVI, Res);
1094+
10881095 gather (&EVI, Res, *VS);
10891096 return true ;
10901097}
0 commit comments