@@ -108,8 +108,15 @@ static bool isNeg(Value *V);
108108static Value *getNegOperand (Value *V);
109109
110110namespace {
111+ template <typename T, typename IterT>
112+ std::optional<T> findCommonBetweenCollections (IterT A, IterT B) {
113+ auto Common = llvm::find_if (A, [B](T I){return llvm::is_contained (B, I);});
114+ if (Common != A.end ())
115+ return std::make_optional (*Common);
116+ return std::nullopt ;
117+ }
111118
112- class ComplexDeinterleavingLegacyPass : public FunctionPass {
119+ class ComplexDeinterleavingLegacyPass : public FunctionPass {
113120public:
114121 static char ID;
115122
@@ -337,7 +344,7 @@ class ComplexDeinterleavingGraph {
337344 NodePtr identifyPartialReduction (Value *R, Value *I);
338345 NodePtr identifyDotProduct (Value *Inst);
339346
340- NodePtr identifyNode (Value *R, Value *I, bool *FromCache = nullptr );
347+ NodePtr identifyNode (Value *R, Value *I);
341348
342349 // / Determine if a sum of complex numbers can be formed from \p RealAddends
343350 // / and \p ImagAddens. If \p Accumulator is not null, add the result to it.
@@ -902,16 +909,16 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
902909
903910ComplexDeinterleavingGraph::NodePtr
904911ComplexDeinterleavingGraph::identifyDotProduct (Value *V) {
905- auto *Inst = cast<Instruction>(V);
906912
907913 if (!TL->isComplexDeinterleavingOperationSupported (
908- ComplexDeinterleavingOperation::CDot, Inst ->getType ())) {
914+ ComplexDeinterleavingOperation::CDot, V ->getType ())) {
909915 LLVM_DEBUG (dbgs () << " Target doesn't support complex deinterleaving "
910916 " operation CDot with the type "
911- << *Inst ->getType () << " \n " );
917+ << *V ->getType () << " \n " );
912918 return nullptr ;
913919 }
914920
921+ auto *Inst = cast<Instruction>(V);
915922 auto *RealUser = cast<Instruction>(*Inst->user_begin ());
916923
917924 NodePtr CN =
@@ -987,13 +994,26 @@ ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
987994 BReal = UnwrapCast (BReal);
988995 BImag = UnwrapCast (BImag);
989996
990- bool WasANodeFromCache = false ;
991- NodePtr Node = identifyNode (AReal, AImag, &WasANodeFromCache);
997+ VectorType *VTy = cast<VectorType>(V->getType ());
998+ Type *ExpectedOperandTy = VectorType::getSubdividedVectorType (VTy, 2 );
999+ if (AReal->getType () != ExpectedOperandTy)
1000+ return nullptr ;
1001+ if (AImag->getType () != ExpectedOperandTy)
1002+ return nullptr ;
1003+ if (BReal->getType () != ExpectedOperandTy)
1004+ return nullptr ;
1005+ if (BImag->getType () != ExpectedOperandTy)
1006+ return nullptr ;
1007+
1008+ if (Phi->getType () != VTy && RealUser->getType () != VTy)
1009+ return nullptr ;
1010+
1011+ NodePtr Node = identifyNode (AReal, AImag);
9921012
9931013 // In the case that a node was identified to figure out the rotation, ensure
9941014 // that trying to identify a node with AReal and AImag post-unwrap results in
9951015 // the same node
996- if (Node && ANode && !WasANodeFromCache ) {
1016+ if (ANode && Node != ANode ) {
9971017 LLVM_DEBUG (
9981018 dbgs ()
9991019 << " Identified node is different from previously identified node. "
@@ -1010,38 +1030,17 @@ ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
10101030
10111031ComplexDeinterleavingGraph::NodePtr
10121032ComplexDeinterleavingGraph::identifyPartialReduction (Value *R, Value *I) {
1013- if (!I->hasOneUser ())
1033+ // Partial reductions don't support non-vector types, so check these first
1034+ if (!isa<VectorType>(R->getType ()) || !isa<VectorType>(I->getType ()))
10141035 return nullptr ;
10151036
1016- VectorType *RealTy = dyn_cast<VectorType>(R->getType ());
1017- if (!RealTy)
1018- return nullptr ;
1019- VectorType *ImagTy = dyn_cast<VectorType>(I->getType ());
1020- if (!ImagTy)
1021- return nullptr ;
1022-
1023- if (RealTy->isScalableTy () != ImagTy->isScalableTy ())
1024- return nullptr ;
1025- if (RealTy->getElementType () != ImagTy->getElementType ())
1026- return nullptr ;
1027-
1028- // `I` is known to only have one user, so iterate over the Phi (R) users to
1029- // find the common user between R and I
1030- auto *CommonUser = *I->user_begin ();
1031- bool CommonUserFound = false ;
1032- for (auto *User : R->users ()) {
1033- if (User == CommonUser) {
1034- CommonUserFound = true ;
1035- break ;
1036- }
1037- }
1038-
1039- if (!CommonUserFound)
1037+ auto CommonUser = findCommonBetweenCollections<Value*>(R->users (), I->users ());
1038+ if (!CommonUser)
10401039 return nullptr ;
10411040
1042- auto *IInst = dyn_cast<IntrinsicInst>(CommonUser);
1041+ auto *IInst = dyn_cast<IntrinsicInst>(* CommonUser);
10431042 if (!IInst || IInst->getIntrinsicID () !=
1044- Intrinsic::experimental_vector_partial_reduce_add)
1043+ Intrinsic::experimental_vector_partial_reduce_add)
10451044 return nullptr ;
10461045
10471046 if (NodePtr CN = identifyDotProduct (IInst))
@@ -1051,12 +1050,10 @@ ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
10511050}
10521051
10531052ComplexDeinterleavingGraph::NodePtr
1054- ComplexDeinterleavingGraph::identifyNode (Value *R, Value *I, bool *FromCache ) {
1053+ ComplexDeinterleavingGraph::identifyNode (Value *R, Value *I) {
10551054 auto It = CachedResult.find ({R, I});
10561055 if (It != CachedResult.end ()) {
10571056 LLVM_DEBUG (dbgs () << " - Folding to existing node\n " );
1058- if (FromCache != nullptr )
1059- *FromCache = true ;
10601057 return It->second ;
10611058 }
10621059
0 commit comments