Skip to content

Commit 8271f4e

Browse files
authored
SONARPY-1770 Rule S6971 : Transformers should not be accessed directly when a Scikit-Learn Pipeline uses caching (#1779)
1 parent 2392e84 commit 8271f4e

File tree

7 files changed

+466
-1
lines changed

7 files changed

+466
-1
lines changed

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ public static Iterable<Class> getChecks() {
348348
StringLiteralDuplicationCheck.class,
349349
StringReplaceCheck.class,
350350
StrongCryptographicKeysCheck.class,
351+
SklearnCachedPipelineDontAccessTransformersCheck.class,
351352
SuperfluousCurlyBraceCheck.class,
352353
TempFileCreationCheck.class,
353354
ImplicitlySkippedTestCheck.class,
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.ArrayList;
23+
import java.util.Collection;
24+
import java.util.HashMap;
25+
import java.util.List;
26+
import java.util.Map;
27+
import java.util.Optional;
28+
import java.util.stream.Stream;
29+
import javax.annotation.Nullable;
30+
import org.sonar.check.Rule;
31+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
32+
import org.sonar.plugins.python.api.SubscriptionContext;
33+
import org.sonar.plugins.python.api.quickfix.PythonQuickFix;
34+
import org.sonar.plugins.python.api.symbols.Symbol;
35+
import org.sonar.plugins.python.api.tree.AssignmentStatement;
36+
import org.sonar.plugins.python.api.tree.CallExpression;
37+
import org.sonar.plugins.python.api.tree.Expression;
38+
import org.sonar.plugins.python.api.tree.ListLiteral;
39+
import org.sonar.plugins.python.api.tree.Name;
40+
import org.sonar.plugins.python.api.tree.QualifiedExpression;
41+
import org.sonar.plugins.python.api.tree.RegularArgument;
42+
import org.sonar.plugins.python.api.tree.StringLiteral;
43+
import org.sonar.plugins.python.api.tree.Tree;
44+
import org.sonar.plugins.python.api.tree.Tuple;
45+
import org.sonar.plugins.python.api.tree.UnpackingExpression;
46+
import org.sonar.python.quickfix.TextEditUtils;
47+
import org.sonar.python.semantic.SymbolUtils;
48+
import org.sonar.python.tree.TreeUtils;
49+
import org.sonar.python.tree.TupleImpl;
50+
import org.sonar.python.types.InferredTypes;
51+
52+
@Rule(key = "S6971")
53+
public class SklearnCachedPipelineDontAccessTransformersCheck extends PythonSubscriptionCheck {
54+
55+
public static final String MESSAGE = "Avoid accessing transformers in a cached pipeline.";
56+
public static final String MESSAGE_SECONDARY = "The transformer is accessed here";
57+
public static final String MESSAGE_SECONDARY_CREATION = "The Pipeline is created here";
58+
59+
@Override
60+
public void initialize(Context context) {
61+
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, SklearnCachedPipelineDontAccessTransformersCheck::checkCallExpr);
62+
}
63+
64+
private static void checkCallExpr(SubscriptionContext subscriptionContext) {
65+
CallExpression callExpression = (CallExpression) subscriptionContext.syntaxNode();
66+
Optional<PipelineCreation> pipelineCreationOptional = isPipelineCreation(callExpression);
67+
if (pipelineCreationOptional.isEmpty()) {
68+
return;
69+
}
70+
PipelineCreation pipelineCreation = pipelineCreationOptional.get();
71+
72+
var memoryArgument = TreeUtils.argumentByKeyword("memory", callExpression.arguments());
73+
if (memoryArgument == null || memoryArgument.expression().is(Tree.Kind.NONE) || memoryArgument.expression().type() == InferredTypes.anyType()) {
74+
return;
75+
}
76+
var stepsArgument = TreeUtils.nthArgumentOrKeyword(0, "steps", callExpression.arguments());
77+
78+
StepsFromPipeline stepsFromPipeline = getStepsFromPipeline(stepsArgument, pipelineCreation);
79+
80+
handleStepNames(subscriptionContext, stepsFromPipeline, pipelineCreation, callExpression);
81+
}
82+
83+
private static StepsFromPipeline getStepsFromPipeline(@Nullable RegularArgument stepsArgument, PipelineCreation pipelineCreation) {
84+
Map<Name, String> nameToStepName = new HashMap<>();
85+
Optional<Expression> stepArgumentExpression = Optional.ofNullable(stepsArgument)
86+
.map(RegularArgument::expression);
87+
88+
var stepNames = stepArgumentExpression.map(
89+
e -> pipelineCreation == PipelineCreation.PIPELINE ? extractFromPipeline(e, nameToStepName) : extractFromMakePipeline(e))
90+
.orElse(Stream.empty());
91+
return new StepsFromPipeline(nameToStepName, stepNames);
92+
}
93+
94+
private record StepsFromPipeline(Map<Name, String> nameToStepName, Stream<Name> stepNames) {
95+
}
96+
97+
private static void handleStepNames(SubscriptionContext subscriptionContext, StepsFromPipeline stepsFromPipeline, PipelineCreation pipelineCreation,
98+
CallExpression callExpression) {
99+
stepsFromPipeline.stepNames()
100+
.map(name -> Map.entry(name, symbolIsUsedInQualifiedExpression(name))).forEach(entry -> {
101+
Name name = entry.getKey();
102+
Map<Tree, QualifiedExpression> uses = entry.getValue();
103+
104+
if (!uses.isEmpty()) {
105+
createIssue(subscriptionContext, stepsFromPipeline, pipelineCreation, callExpression, name, uses);
106+
}
107+
});
108+
}
109+
110+
private static void createIssue(SubscriptionContext subscriptionContext, StepsFromPipeline stepsFromPipeline, PipelineCreation pipelineCreation, CallExpression callExpression,
111+
Name name, Map<Tree, QualifiedExpression> uses) {
112+
var issue = subscriptionContext.addIssue(name, MESSAGE);
113+
uses.forEach((useTree, qualExpr) -> issue.secondary(useTree, MESSAGE_SECONDARY));
114+
if (pipelineCreation == PipelineCreation.PIPELINE) {
115+
issue.secondary(callExpression.callee(), MESSAGE_SECONDARY_CREATION);
116+
uses
117+
.forEach(
118+
(useTree, qualExpr) -> getAssignedName(callExpression)
119+
.flatMap(pipelineBindingVariable -> getQuickFix(pipelineBindingVariable, name, qualExpr, stepsFromPipeline.nameToStepName()))
120+
.ifPresent(issue::addQuickFix));
121+
}
122+
}
123+
124+
private static Stream<Name> extractFromMakePipeline(Expression stepArgumentExpression) {
125+
return Optional.of(stepArgumentExpression)
126+
.filter(e -> e.is(Tree.Kind.NAME))
127+
.map(Name.class::cast)
128+
.stream();
129+
}
130+
131+
private static Stream<Name> extractFromPipeline(Expression stepArgumentExpression, Map<Name, String> nameToStepName) {
132+
return Optional.of(stepArgumentExpression)
133+
.filter(e -> e.is(Tree.Kind.LIST_LITERAL))
134+
.map(e -> ((ListLiteral) e).elements().expressions())
135+
.stream()
136+
.flatMap(Collection::stream)
137+
.filter(e -> e.is(Tree.Kind.TUPLE))
138+
.map(t -> ((TupleImpl) t).elements())
139+
.filter(e -> e.size() == 2)
140+
.filter(e -> e.get(1).is(Tree.Kind.NAME))
141+
.map(elements -> {
142+
if (elements.get(0).is(Tree.Kind.STRING_LITERAL)) {
143+
nameToStepName.put((Name) elements.get(1), ((StringLiteral) elements.get(0)).trimmedQuotesValue());
144+
}
145+
return elements;
146+
})
147+
.map(e -> e.get(1))
148+
.map(Name.class::cast);
149+
}
150+
151+
private static Optional<PythonQuickFix> getQuickFix(Name pipelineBindingVariable, Tree name, QualifiedExpression qualifiedExpression, Map<Name, String> nameToStepName) {
152+
return Optional.ofNullable(nameToStepName.get(name))
153+
.map(stepName -> PythonQuickFix.newQuickFix("Access the property through the ")
154+
.addTextEdit(TextEditUtils.replace(qualifiedExpression.qualifier(), String.format("%s.named_steps[\"%s\"]", pipelineBindingVariable.name(), stepName)))
155+
.build());
156+
}
157+
158+
private static Map<Tree, QualifiedExpression> symbolIsUsedInQualifiedExpression(Name name) {
159+
Symbol symbol = name.symbol();
160+
if (symbol == null) {
161+
return new HashMap<>();
162+
}
163+
Map<Tree, QualifiedExpression> qualifiedExpressions = new HashMap<>();
164+
symbol.usages().stream()
165+
.filter(u -> u.tree().parent().is(Tree.Kind.QUALIFIED_EXPR))
166+
.forEach(u -> qualifiedExpressions.put(((QualifiedExpression) u.tree().parent()).qualifier(), ((QualifiedExpression) u.tree().parent())));
167+
168+
return qualifiedExpressions;
169+
}
170+
171+
private enum PipelineCreation {
172+
PIPELINE,
173+
MAKE_PIPELINE
174+
}
175+
176+
private static Optional<PipelineCreation> isPipelineCreation(CallExpression callExpression) {
177+
return Optional.ofNullable(callExpression.calleeSymbol()).map(Symbol::fullyQualifiedName)
178+
.map(fqn -> {
179+
if ("sklearn.pipeline.Pipeline".equals(fqn)) {
180+
return PipelineCreation.PIPELINE;
181+
}
182+
if ("sklearn.pipeline.make_pipeline".equals(fqn)) {
183+
return PipelineCreation.MAKE_PIPELINE;
184+
}
185+
return null;
186+
});
187+
}
188+
189+
private static Optional<Name> getAssignedName(Expression expression) {
190+
if (expression.is(Tree.Kind.NAME)) {
191+
return Optional.of((Name) expression);
192+
}
193+
if (expression.is(Tree.Kind.QUALIFIED_EXPR)) {
194+
return getAssignedName(((QualifiedExpression) expression).name());
195+
}
196+
197+
var assignment = (AssignmentStatement) TreeUtils.firstAncestorOfKind(expression, Tree.Kind.ASSIGNMENT_STMT);
198+
if (assignment == null) {
199+
return Optional.empty();
200+
}
201+
202+
var expressions = SymbolUtils.assignmentsLhs(assignment);
203+
if (expressions.size() != 1) {
204+
List<Expression> rhsExpressions = getExpressionsFromRhs(assignment.assignedValue());
205+
var rhsIndex = rhsExpressions.indexOf(expression);
206+
if (rhsIndex != -1) {
207+
return getAssignedName(expressions.get(rhsIndex));
208+
} else {
209+
return Optional.empty();
210+
}
211+
}
212+
213+
return getAssignedName(expressions.get(0));
214+
}
215+
216+
private static List<Expression> getExpressionsFromRhs(Expression rhs) {
217+
List<Expression> expressions = new ArrayList<>();
218+
if (rhs.is(Tree.Kind.TUPLE)) {
219+
expressions.addAll(((Tuple) rhs).elements());
220+
} else if (rhs.is(Tree.Kind.LIST_LITERAL)) {
221+
expressions.addAll(((ListLiteral) rhs).elements().expressions());
222+
} else if (rhs.is(Tree.Kind.UNPACKING_EXPR)) {
223+
return getExpressionsFromRhs(((UnpackingExpression) rhs).expression());
224+
}
225+
return expressions;
226+
}
227+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
<p>This rule raises an issue when trying to access a Scikit-Learn transformer used in a pipeline with caching directly.</p>
2+
<h2>Why is this an issue?</h2>
3+
<p>When using a pipeline with a cache and passing the transformer objects as an instance from a variable, it is possible to access them directly.</p>
4+
<p>This is an issue, since when the Pipeline is fitted, all the transformers are cloned. The objects outside the Pipeline are therefore not updated,
5+
and will yield unexpected results.</p>
6+
<h2>How to fix it</h2>
7+
<p>Replace the direct access to the transformer with an access to the <code>named_steps</code> attribute of the pipeline.</p>
8+
<h3>Code examples</h3>
9+
<h4>Noncompliant code example</h4>
10+
<pre data-diff-id="1" data-diff-type="noncompliant">
11+
from sklearn.datasets import load_diabetes
12+
from sklearn.preprocessing import RobustScaler
13+
from sklearn.neighbors import KNeighborsRegressor
14+
from sklearn.pipeline import Pipeline
15+
16+
diabetes = load_diabetes()
17+
scaler = RobustScaler()
18+
knn = KNeighborsRegressor(n_neighbors=5)
19+
20+
pipeline = Pipeline([
21+
('scaler', scaler),
22+
('knn', knn)
23+
], memory="cache")
24+
25+
pipeline.fit(diabetes.data, diabetes.target)
26+
print(scaler.center_) # Noncompliant : raises an AttributeError
27+
</pre>
28+
<h4>Compliant solution</h4>
29+
<pre data-diff-id="1" data-diff-type="compliant">
30+
from sklearn.datasets import load_diabetes
31+
from sklearn.preprocessing import RobustScaler
32+
from sklearn.neighbors import KNeighborsRegressor
33+
from sklearn.pipeline import Pipeline
34+
35+
diabetes = load_diabetes()
36+
scaler = RobustScaler()
37+
knn = KNeighborsRegressor(n_neighbors=5)
38+
39+
pipeline = Pipeline([
40+
('scaler', scaler),
41+
('knn', knn)
42+
], memory="cache")
43+
44+
pipeline.fit(diabetes.data, diabetes.target)
45+
print(pipeline.named_steps['scaler'].center_) # Compliant
46+
</pre>
47+
<h2>Resources</h2>
48+
<h3>Documentation</h3>
49+
<ul>
50+
<li> Scikit-Learn - Pipelines and composite estimators : <a
51+
href="https://scikit-learn.org/stable/modules/compose.html#warning:-side-effect-of-caching-transformers">Side effect of caching transformers</a>
52+
</li>
53+
</ul>
54+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
"title": "Transformers should not be accessed directly when a Scikit-Learn Pipeline uses caching",
3+
"type": "BUG",
4+
"status": "ready",
5+
"remediation": {
6+
"func": "Constant\/Issue",
7+
"constantCost": "5min"
8+
},
9+
"tags": [],
10+
"defaultSeverity": "Major",
11+
"ruleSpecification": "RSPEC-6971",
12+
"sqKey": "S6971",
13+
"scope": "All",
14+
"quickfix": "targeted",
15+
"code": {
16+
"impacts": {
17+
"MAINTAINABILITY": "MEDIUM",
18+
"RELIABILITY": "HIGH"
19+
},
20+
"attribute": "LOGICAL"
21+
}
22+
}

python-checks/src/main/resources/org/sonar/l10n/py/rules/python/Sonar_way_profile.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@
238238
"S6919",
239239
"S6925",
240240
"S6928",
241-
"S6929"
241+
"S6929",
242+
"S6971"
242243
]
243244
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.sonar.python.checks.quickfix.PythonQuickFixVerifier;
24+
import org.sonar.python.checks.utils.PythonCheckVerifier;
25+
26+
27+
class SklearnCachedPipelineDontAccessTransformersCheckTest {
28+
29+
@Test
30+
void test(){
31+
PythonCheckVerifier.verify("src/test/resources/checks/sklearn_cached_pipeline_dont_access_transformers.py", new SklearnCachedPipelineDontAccessTransformersCheck());
32+
}
33+
34+
@Test
35+
void test_quickfix1(){
36+
PythonQuickFixVerifier.verify(
37+
new SklearnCachedPipelineDontAccessTransformersCheck(),
38+
"""
39+
from sklearn.pipeline import Pipeline
40+
scaler = RobustScaler()
41+
knn = KNeighborsRegressor(n_neighbors=5)
42+
43+
pipeline = Pipeline([
44+
('scaler', scaler),
45+
('knn', knn),
46+
], memory="cache")
47+
print(scaler.center_)
48+
""",
49+
"""
50+
from sklearn.pipeline import Pipeline
51+
scaler = RobustScaler()
52+
knn = KNeighborsRegressor(n_neighbors=5)
53+
54+
pipeline = Pipeline([
55+
('scaler', scaler),
56+
('knn', knn),
57+
], memory="cache")
58+
print(pipeline.named_steps["scaler"].center_)
59+
"""
60+
);
61+
}
62+
}

0 commit comments

Comments
 (0)