Skip to content

Commit 4591813

Browse files
authored
SONARPY-1834 check if pipeline is used by sklearn.compose estimator (#2034)
* SONARPY-1834 check if pipeline is used by sklearn.compose estimator * SONARPY-1834 small refactoring
1 parent 841f35c commit 4591813

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

python-checks/src/main/java/org/sonar/python/checks/SklearnPipelineSpecifyMemoryArgumentCheck.java

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ private static void checkCallExpression(SubscriptionContext subscriptionContext)
5656
return;
5757
}
5858

59-
if (getAssignedName(callExpression).map(SklearnPipelineSpecifyMemoryArgumentCheck::isUsedInAnotherPipeline).orElse(false)) {
59+
boolean isUsedInAnotherPipeline = getAssignedName(callExpression)
60+
.map(SklearnPipelineSpecifyMemoryArgumentCheck::isUsedInAnotherPipeline)
61+
.orElse(false);
62+
63+
if (isUsedInAnotherPipeline) {
6064
return;
6165
}
6266

@@ -72,26 +76,44 @@ private static void createIssue(SubscriptionContext subscriptionContext, CallExp
7276
issue.addQuickFix(quickFix);
7377
}
7478

75-
private static boolean isPipelineCreation(CallExpression callExpression) {
76-
return Optional.ofNullable(callExpression.calleeSymbol())
77-
.map(Symbol::fullyQualifiedName)
78-
.map(fqn -> "sklearn.pipeline.Pipeline".equals(fqn) || "sklearn.pipeline.make_pipeline".equals(fqn))
79-
.orElse(false);
80-
}
81-
8279
private static boolean isUsedInAnotherPipeline(Name name) {
8380
Symbol symbol = name.symbol();
8481
return symbol != null && symbol.usages().stream().filter(usage -> !usage.isBindingUsage()).anyMatch(u -> {
8582
Tree tree = u.tree();
8683
CallExpression callExpression = (CallExpression) TreeUtils.firstAncestorOfKind(tree, Tree.Kind.CALL_EXPR);
8784
while (callExpression != null) {
88-
Optional<String> fullyQualifiedName = Optional.ofNullable(callExpression.calleeSymbol()).map(Symbol::fullyQualifiedName);
89-
if (fullyQualifiedName.isPresent() && isPipelineCreation(callExpression)) {
85+
if (isUsedBySklearnComposeEstimatorOrPipelineCreation(callExpression)) {
9086
return true;
9187
}
9288
callExpression = (CallExpression) TreeUtils.firstAncestorOfKind(callExpression, Tree.Kind.CALL_EXPR);
9389
}
9490
return false;
9591
});
9692
}
93+
94+
private static boolean isPipelineCreation(CallExpression callExpression) {
95+
return Optional.ofNullable(callExpression.calleeSymbol())
96+
.map(Symbol::fullyQualifiedName)
97+
.map(SklearnPipelineSpecifyMemoryArgumentCheck::isFullyQualifiedNameAPipelineCreation)
98+
.orElse(false);
99+
}
100+
101+
private static boolean isUsedBySklearnComposeEstimatorOrPipelineCreation(CallExpression callExpression) {
102+
Symbol calleeSymbol = callExpression.calleeSymbol();
103+
if(calleeSymbol == null) return false;
104+
105+
String fqn = calleeSymbol.fullyQualifiedName();
106+
if(fqn == null) return false;
107+
108+
return isFullyQualifiedNameAPipelineCreation(fqn) || isFullyQualifiedNameASklearnComposeEstimator(fqn);
109+
}
110+
111+
private static boolean isFullyQualifiedNameAPipelineCreation(String fqn) {
112+
return "sklearn.pipeline.Pipeline".equals(fqn) || "sklearn.pipeline.make_pipeline".equals(fqn);
113+
}
114+
115+
private static boolean isFullyQualifiedNameASklearnComposeEstimator(String fqn) {
116+
return fqn.startsWith("sklearn.compose.");
117+
}
118+
97119
}

python-checks/src/test/resources/checks/sklearn_pipeline_specify_memory_argument.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,17 @@ def __init__(self):
120120
#^^^^^^^^^^^^^
121121
r1 = make_pipeline(p5) # Noncompliant
122122

123+
def more_nested5():
124+
from sklearn.pipeline import Pipeline
125+
from sklearn.compose import ColumnTransformer, make_column_selector
126+
numeric_transformer = Pipeline(steps=[('scaler', StandardScaler())])
127+
128+
preprocessor = ColumnTransformer(transformers=[('num', numeric_transformer, list(X_train.columns.values))])
129+
clf = Pipeline(steps=[('preprocessor', preprocessor),
130+
('classifier', LogisticRegression())])
131+
132+
preprocessor = make_column_selector(transformers=[('num', clf, list(X_train.columns.values))])
133+
123134

124135
def other():
125136
from sklearn.pipeline import Pipeline, make_pipeline

0 commit comments

Comments
 (0)