@@ -56,7 +56,11 @@ private static void checkCallExpression(SubscriptionContext subscriptionContext)
56
56
return ;
57
57
}
58
58
59
- if (getAssignedName (callExpression ).map (SklearnPipelineSpecifyMemoryArgumentCheck ::isUsedInAnotherPipeline ).orElse (false )) {
59
+ boolean isUsedInAnotherPipeline = getAssignedName (callExpression )
60
+ .map (SklearnPipelineSpecifyMemoryArgumentCheck ::isUsedInAnotherPipeline )
61
+ .orElse (false );
62
+
63
+ if (isUsedInAnotherPipeline ) {
60
64
return ;
61
65
}
62
66
@@ -72,26 +76,44 @@ private static void createIssue(SubscriptionContext subscriptionContext, CallExp
72
76
issue .addQuickFix (quickFix );
73
77
}
74
78
75
- private static boolean isPipelineCreation (CallExpression callExpression ) {
76
- return Optional .ofNullable (callExpression .calleeSymbol ())
77
- .map (Symbol ::fullyQualifiedName )
78
- .map (fqn -> "sklearn.pipeline.Pipeline" .equals (fqn ) || "sklearn.pipeline.make_pipeline" .equals (fqn ))
79
- .orElse (false );
80
- }
81
-
82
79
private static boolean isUsedInAnotherPipeline (Name name ) {
83
80
Symbol symbol = name .symbol ();
84
81
return symbol != null && symbol .usages ().stream ().filter (usage -> !usage .isBindingUsage ()).anyMatch (u -> {
85
82
Tree tree = u .tree ();
86
83
CallExpression callExpression = (CallExpression ) TreeUtils .firstAncestorOfKind (tree , Tree .Kind .CALL_EXPR );
87
84
while (callExpression != null ) {
88
- Optional <String > fullyQualifiedName = Optional .ofNullable (callExpression .calleeSymbol ()).map (Symbol ::fullyQualifiedName );
89
- if (fullyQualifiedName .isPresent () && isPipelineCreation (callExpression )) {
85
+ if (isUsedBySklearnComposeEstimatorOrPipelineCreation (callExpression )) {
90
86
return true ;
91
87
}
92
88
callExpression = (CallExpression ) TreeUtils .firstAncestorOfKind (callExpression , Tree .Kind .CALL_EXPR );
93
89
}
94
90
return false ;
95
91
});
96
92
}
93
+
94
+ private static boolean isPipelineCreation (CallExpression callExpression ) {
95
+ return Optional .ofNullable (callExpression .calleeSymbol ())
96
+ .map (Symbol ::fullyQualifiedName )
97
+ .map (SklearnPipelineSpecifyMemoryArgumentCheck ::isFullyQualifiedNameAPipelineCreation )
98
+ .orElse (false );
99
+ }
100
+
101
+ private static boolean isUsedBySklearnComposeEstimatorOrPipelineCreation (CallExpression callExpression ) {
102
+ Symbol calleeSymbol = callExpression .calleeSymbol ();
103
+ if (calleeSymbol == null ) return false ;
104
+
105
+ String fqn = calleeSymbol .fullyQualifiedName ();
106
+ if (fqn == null ) return false ;
107
+
108
+ return isFullyQualifiedNameAPipelineCreation (fqn ) || isFullyQualifiedNameASklearnComposeEstimator (fqn );
109
+ }
110
+
111
+ private static boolean isFullyQualifiedNameAPipelineCreation (String fqn ) {
112
+ return "sklearn.pipeline.Pipeline" .equals (fqn ) || "sklearn.pipeline.make_pipeline" .equals (fqn );
113
+ }
114
+
115
+ private static boolean isFullyQualifiedNameASklearnComposeEstimator (String fqn ) {
116
+ return fqn .startsWith ("sklearn.compose." );
117
+ }
118
+
97
119
}
0 commit comments