Skip to content

Commit 92d62ef

Browse files
authored
Rule S6969 : "memory" parameter should be specified for Scikit-Learn Pipeline (#1772)
1 parent 8271f4e commit 92d62ef

File tree

10 files changed

+455
-45
lines changed

10 files changed

+455
-45
lines changed

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ public static Iterable<Class> getChecks() {
350350
StrongCryptographicKeysCheck.class,
351351
SklearnCachedPipelineDontAccessTransformersCheck.class,
352352
SuperfluousCurlyBraceCheck.class,
353+
SklearnPipelineSpecifyMemoryArgumentCheck.class,
353354
TempFileCreationCheck.class,
354355
ImplicitlySkippedTestCheck.class,
355356
ToDoCommentCheck.class,

python-checks/src/main/java/org/sonar/python/checks/SklearnCachedPipelineDontAccessTransformersCheck.java

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@
1919
*/
2020
package org.sonar.python.checks;
2121

22-
import java.util.ArrayList;
2322
import java.util.Collection;
2423
import java.util.HashMap;
25-
import java.util.List;
2624
import java.util.Map;
2725
import java.util.Optional;
2826
import java.util.stream.Stream;
@@ -32,7 +30,6 @@
3230
import org.sonar.plugins.python.api.SubscriptionContext;
3331
import org.sonar.plugins.python.api.quickfix.PythonQuickFix;
3432
import org.sonar.plugins.python.api.symbols.Symbol;
35-
import org.sonar.plugins.python.api.tree.AssignmentStatement;
3633
import org.sonar.plugins.python.api.tree.CallExpression;
3734
import org.sonar.plugins.python.api.tree.Expression;
3835
import org.sonar.plugins.python.api.tree.ListLiteral;
@@ -41,14 +38,13 @@
4138
import org.sonar.plugins.python.api.tree.RegularArgument;
4239
import org.sonar.plugins.python.api.tree.StringLiteral;
4340
import org.sonar.plugins.python.api.tree.Tree;
44-
import org.sonar.plugins.python.api.tree.Tuple;
45-
import org.sonar.plugins.python.api.tree.UnpackingExpression;
4641
import org.sonar.python.quickfix.TextEditUtils;
47-
import org.sonar.python.semantic.SymbolUtils;
4842
import org.sonar.python.tree.TreeUtils;
4943
import org.sonar.python.tree.TupleImpl;
5044
import org.sonar.python.types.InferredTypes;
5145

46+
import static org.sonar.python.checks.utils.Expressions.getAssignedName;
47+
5248
@Rule(key = "S6971")
5349
public class SklearnCachedPipelineDontAccessTransformersCheck extends PythonSubscriptionCheck {
5450

@@ -185,43 +181,4 @@ private static Optional<PipelineCreation> isPipelineCreation(CallExpression call
185181
return null;
186182
});
187183
}
188-
189-
private static Optional<Name> getAssignedName(Expression expression) {
190-
if (expression.is(Tree.Kind.NAME)) {
191-
return Optional.of((Name) expression);
192-
}
193-
if (expression.is(Tree.Kind.QUALIFIED_EXPR)) {
194-
return getAssignedName(((QualifiedExpression) expression).name());
195-
}
196-
197-
var assignment = (AssignmentStatement) TreeUtils.firstAncestorOfKind(expression, Tree.Kind.ASSIGNMENT_STMT);
198-
if (assignment == null) {
199-
return Optional.empty();
200-
}
201-
202-
var expressions = SymbolUtils.assignmentsLhs(assignment);
203-
if (expressions.size() != 1) {
204-
List<Expression> rhsExpressions = getExpressionsFromRhs(assignment.assignedValue());
205-
var rhsIndex = rhsExpressions.indexOf(expression);
206-
if (rhsIndex != -1) {
207-
return getAssignedName(expressions.get(rhsIndex));
208-
} else {
209-
return Optional.empty();
210-
}
211-
}
212-
213-
return getAssignedName(expressions.get(0));
214-
}
215-
216-
private static List<Expression> getExpressionsFromRhs(Expression rhs) {
217-
List<Expression> expressions = new ArrayList<>();
218-
if (rhs.is(Tree.Kind.TUPLE)) {
219-
expressions.addAll(((Tuple) rhs).elements());
220-
} else if (rhs.is(Tree.Kind.LIST_LITERAL)) {
221-
expressions.addAll(((ListLiteral) rhs).elements().expressions());
222-
} else if (rhs.is(Tree.Kind.UNPACKING_EXPR)) {
223-
return getExpressionsFromRhs(((UnpackingExpression) rhs).expression());
224-
}
225-
return expressions;
226-
}
227184
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.Optional;
23+
import org.sonar.check.Rule;
24+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
25+
import org.sonar.plugins.python.api.SubscriptionContext;
26+
import org.sonar.plugins.python.api.quickfix.PythonQuickFix;
27+
import org.sonar.plugins.python.api.symbols.Symbol;
28+
import org.sonar.plugins.python.api.tree.CallExpression;
29+
import org.sonar.plugins.python.api.tree.Name;
30+
import org.sonar.plugins.python.api.tree.Tree;
31+
import org.sonar.python.quickfix.TextEditUtils;
32+
import org.sonar.python.tree.TreeUtils;
33+
34+
import static org.sonar.python.checks.utils.Expressions.getAssignedName;
35+
36+
@Rule(key = "S6969")
37+
public class SklearnPipelineSpecifyMemoryArgumentCheck extends PythonSubscriptionCheck {
38+
39+
public static final String MESSAGE = "Specify a memory argument for the pipeline.";
40+
public static final String MESSAGE_QUICKFIX = "Add the memory argument";
41+
42+
@Override
43+
public void initialize(Context context) {
44+
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, SklearnPipelineSpecifyMemoryArgumentCheck::checkCallExpression);
45+
}
46+
47+
private static void checkCallExpression(SubscriptionContext subscriptionContext) {
48+
Optional.of(subscriptionContext.syntaxNode())
49+
.map(CallExpression.class::cast)
50+
.filter(SklearnPipelineSpecifyMemoryArgumentCheck::isPipelineCreation)
51+
.ifPresent(
52+
callExpression -> {
53+
var memoryArgument = TreeUtils.argumentByKeyword("memory", callExpression.arguments());
54+
55+
if (memoryArgument != null) {
56+
return;
57+
}
58+
59+
if (getAssignedName(callExpression).map(SklearnPipelineSpecifyMemoryArgumentCheck::isUsedInAnotherPipeline).orElse(false)) {
60+
return;
61+
}
62+
63+
createIssue(subscriptionContext, callExpression);
64+
});
65+
}
66+
67+
private static void createIssue(SubscriptionContext subscriptionContext, CallExpression callExpression) {
68+
var issue = subscriptionContext.addIssue(callExpression.callee(), MESSAGE);
69+
var quickFix = PythonQuickFix.newQuickFix(MESSAGE_QUICKFIX)
70+
.addTextEdit(TextEditUtils.insertBefore(callExpression.rightPar(), ", memory=None"))
71+
.build();
72+
issue.addQuickFix(quickFix);
73+
}
74+
75+
private static boolean isPipelineCreation(CallExpression callExpression) {
76+
return Optional.ofNullable(callExpression.calleeSymbol())
77+
.map(Symbol::fullyQualifiedName)
78+
.map(fqn -> "sklearn.pipeline.Pipeline".equals(fqn) || "sklearn.pipeline.make_pipeline".equals(fqn))
79+
.orElse(false);
80+
}
81+
82+
private static boolean isUsedInAnotherPipeline(Name name) {
83+
Symbol symbol = name.symbol();
84+
return symbol != null && symbol.usages().stream().filter(usage -> !usage.isBindingUsage()).anyMatch(u -> {
85+
Tree tree = u.tree();
86+
CallExpression callExpression = (CallExpression) TreeUtils.firstAncestorOfKind(tree, Tree.Kind.CALL_EXPR);
87+
while (callExpression != null) {
88+
Optional<String> fullyQualifiedName = Optional.ofNullable(callExpression.calleeSymbol()).map(Symbol::fullyQualifiedName);
89+
if (fullyQualifiedName.isPresent() && isPipelineCreation(callExpression)) {
90+
return true;
91+
}
92+
callExpression = (CallExpression) TreeUtils.firstAncestorOfKind(callExpression, Tree.Kind.CALL_EXPR);
93+
}
94+
return false;
95+
});
96+
}
97+
}

python-checks/src/main/java/org/sonar/python/checks/utils/Expressions.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.sonar.python.checks.utils;
2121

22+
import java.util.ArrayList;
2223
import java.util.Arrays;
2324
import java.util.Collections;
2425
import java.util.HashSet;
@@ -39,11 +40,14 @@
3940
import org.sonar.plugins.python.api.tree.Name;
4041
import org.sonar.plugins.python.api.tree.NumericLiteral;
4142
import org.sonar.plugins.python.api.tree.ParenthesizedExpression;
43+
import org.sonar.plugins.python.api.tree.QualifiedExpression;
4244
import org.sonar.plugins.python.api.tree.StringElement;
4345
import org.sonar.plugins.python.api.tree.StringLiteral;
4446
import org.sonar.plugins.python.api.tree.Tree;
4547
import org.sonar.plugins.python.api.tree.Tree.Kind;
4648
import org.sonar.plugins.python.api.tree.Tuple;
49+
import org.sonar.plugins.python.api.tree.UnpackingExpression;
50+
import org.sonar.python.semantic.SymbolUtils;
4751
import org.sonar.python.tree.TreeUtils;
4852

4953
public class Expressions {
@@ -302,7 +306,44 @@ private static EscapeSequence extractOctal(String value, int i) {
302306
return IGNORE;
303307
}
304308
}
309+
}
310+
311+
public static Optional<Name> getAssignedName(Expression expression) {
312+
if (expression.is(Tree.Kind.NAME)) {
313+
return Optional.of((Name) expression);
314+
}
315+
if (expression.is(Tree.Kind.QUALIFIED_EXPR)) {
316+
return getAssignedName(((QualifiedExpression) expression).name());
317+
}
318+
319+
var assignment = (AssignmentStatement) TreeUtils.firstAncestorOfKind(expression, Tree.Kind.ASSIGNMENT_STMT);
320+
if (assignment == null) {
321+
return Optional.empty();
322+
}
305323

324+
var expressions = SymbolUtils.assignmentsLhs(assignment);
325+
if (expressions.size() != 1) {
326+
List<Expression> rhsExpressions = getExpressionsFromRhs(assignment.assignedValue());
327+
var rhsIndex = rhsExpressions.stream().flatMap(TreeUtils::flattenTuples).toList().indexOf(expression);
328+
if (rhsIndex != -1) {
329+
return getAssignedName(expressions.get(rhsIndex));
330+
} else {
331+
return Optional.empty();
332+
}
333+
}
334+
335+
return getAssignedName(expressions.get(0));
306336
}
307337

338+
private static List<Expression> getExpressionsFromRhs(Expression rhs) {
339+
List<Expression> expressions = new ArrayList<>();
340+
if (rhs.is(Tree.Kind.TUPLE)) {
341+
expressions.addAll(((Tuple) rhs).elements());
342+
} else if (rhs.is(Tree.Kind.LIST_LITERAL)) {
343+
expressions.addAll(((ListLiteral) rhs).elements().expressions());
344+
} else if (rhs.is(Tree.Kind.UNPACKING_EXPR)) {
345+
return getExpressionsFromRhs(((UnpackingExpression) rhs).expression());
346+
}
347+
return expressions;
348+
}
308349
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
<p>This rule raises an issue when a Scikit-Learn Pipeline is created without specifying the <code>memory</code> argument.</p>
2+
<h2>Why is this an issue?</h2>
3+
<p>When the <code>memory</code> argument is not specified, the pipeline will recompute the transformers every time the pipeline is fitted. This can be
4+
time-consuming if the transformers are expensive to compute or if the dataset is large.</p>
5+
<p>However, if the intent is to recompute the transformers everytime, the memory argument should be set explicitly to <code>None</code>. This way the
6+
intention is clear.</p>
7+
<h2>How to fix it</h2>
8+
<p>Specify the <code>memory</code> argument when creating a Scikit-Learn Pipeline.</p>
9+
<h3>Code examples</h3>
10+
<h4>Noncompliant code example</h4>
11+
<pre data-diff-id="1" data-diff-type="noncompliant">
12+
from sklearn.pipeline import Pipeline
13+
from sklearn.preprocessing import StandardScaler
14+
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
15+
16+
pipeline = Pipeline([
17+
('scaler', StandardScaler()),
18+
('classifier', LinearDiscriminantAnalysis())
19+
]) # Noncompliant: the memory parameter is not provided
20+
</pre>
21+
<h4>Compliant solution</h4>
22+
<pre data-diff-id="1" data-diff-type="compliant">
23+
from sklearn.pipeline import Pipeline
24+
from sklearn.preprocessing import StandardScaler
25+
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
26+
27+
pipeline = Pipeline([
28+
('scaler', StandardScaler()),
29+
('classifier', LinearDiscriminantAnalysis())
30+
], memory="cache_folder") # Compliant
31+
</pre>
32+
<h3>Pitfalls</h3>
33+
<p>If the pipeline is used with different datasets, the cache may not be helpful and can consume a lot of space. This is true when using
34+
<code>sklearn.model_selection.HalvingGridSearchCV</code> or <code>sklearn.model_selection.HalvingRandomSearchCV</code> because the size of the dataset
35+
changes every iteration when using the default configuration.</p>
36+
<h2>Resources</h2>
37+
<h3>Documentation</h3>
38+
<ul>
39+
<li> Scikit-Learn documentation - <a
40+
href="https://scikit-learn.org/stable/modules/compose.html#caching-transformers-avoid-repeated-computation">Pipeline</a> </li>
41+
</ul>
42+
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"title": "\"memory\" parameter should be specified for Scikit-Learn Pipeline",
3+
"type": "CODE_SMELL",
4+
"status": "ready",
5+
"remediation": {
6+
"func": "Constant\/Issue",
7+
"constantCost": "5min"
8+
},
9+
"tags": [],
10+
"defaultSeverity": "Major",
11+
"ruleSpecification": "RSPEC-6969",
12+
"sqKey": "S6969",
13+
"scope": "All",
14+
"quickfix": "targeted",
15+
"code": {
16+
"impacts": {
17+
"RELIABILITY": "LOW"
18+
},
19+
"attribute": "EFFICIENT"
20+
}
21+
}

python-checks/src/main/resources/org/sonar/l10n/py/rules/python/Sonar_way_profile.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@
239239
"S6925",
240240
"S6928",
241241
"S6929",
242+
"S6969",
242243
"S6971"
243244
]
244245
}

0 commit comments

Comments
 (0)