Skip to content

Commit ca9525f

Browse files
authored
SONARPY-1771 Rule S6972: Nested estimator parameters adjustment in a Pipeline should refer to valid parameters (#1789)
1 parent b28933c commit ca9525f

File tree

7 files changed

+536
-2
lines changed

7 files changed

+536
-2
lines changed

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,10 @@ public static Iterable<Class> getChecks() {
350350
StringReplaceCheck.class,
351351
StrongCryptographicKeysCheck.class,
352352
SklearnCachedPipelineDontAccessTransformersCheck.class,
353-
SuperfluousCurlyBraceCheck.class,
354-
SklearnPipelineSpecifyMemoryArgumentCheck.class,
355353
SklearnEstimatorHyperparametersCheck.class,
354+
SklearnPipelineSpecifyMemoryArgumentCheck.class,
355+
SklearnPipelineParameterAreCorrectCheck.class,
356+
SuperfluousCurlyBraceCheck.class,
356357
TempFileCreationCheck.class,
357358
ImplicitlySkippedTestCheck.class,
358359
ToDoCommentCheck.class,
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.Arrays;
23+
import java.util.Collection;
24+
import java.util.HashMap;
25+
import java.util.HashSet;
26+
import java.util.List;
27+
import java.util.Map;
28+
import java.util.Objects;
29+
import java.util.Optional;
30+
import java.util.Set;
31+
import java.util.function.Function;
32+
import java.util.stream.Collector;
33+
import java.util.stream.Collectors;
34+
import java.util.stream.Stream;
35+
import javax.annotation.Nullable;
36+
import org.sonar.check.Rule;
37+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
38+
import org.sonar.plugins.python.api.SubscriptionContext;
39+
import org.sonar.plugins.python.api.symbols.ClassSymbol;
40+
import org.sonar.plugins.python.api.symbols.FunctionSymbol;
41+
import org.sonar.plugins.python.api.symbols.Symbol;
42+
import org.sonar.plugins.python.api.tree.CallExpression;
43+
import org.sonar.plugins.python.api.tree.DictionaryLiteral;
44+
import org.sonar.plugins.python.api.tree.Expression;
45+
import org.sonar.plugins.python.api.tree.ExpressionList;
46+
import org.sonar.plugins.python.api.tree.Name;
47+
import org.sonar.plugins.python.api.tree.QualifiedExpression;
48+
import org.sonar.plugins.python.api.tree.RegularArgument;
49+
import org.sonar.plugins.python.api.tree.StringLiteral;
50+
import org.sonar.plugins.python.api.tree.Tree;
51+
import org.sonar.python.checks.utils.Expressions;
52+
import org.sonar.python.tree.DictionaryLiteralImpl;
53+
import org.sonar.python.tree.KeyValuePairImpl;
54+
import org.sonar.python.tree.ListLiteralImpl;
55+
import org.sonar.python.tree.TreeUtils;
56+
import org.sonar.python.tree.TupleImpl;
57+
58+
@Rule(key = "S6972")
59+
public class SklearnPipelineParameterAreCorrectCheck extends PythonSubscriptionCheck {
60+
61+
public static final String MESSAGE = "Provide valid parameters to the estimator.";
62+
private static final Set<String> SKLEARN_SEARCH_FQNS = Set.of(
63+
"sklearn.model_selection._search.GridSearchCV",
64+
"sklearn.model_selection._search_successive_halving.HalvingGridSearchCV",
65+
"sklearn.model_selection._search.RandomizedSearchCV",
66+
"sklearn.model_selection._search_successive_halving.HalvingRandomSearchCV");
67+
68+
private record PipelineNameAndParsedParameters(Name pipelineName, Map<String, Set<ParameterNameAndLocation>> parsedParameters) {
69+
}
70+
private record ExpressionAndPrefix(List<Expression> tuple, String prefix, int depth) {
71+
}
72+
private record ParameterNameAndLocation(String string, Tree tree) {
73+
}
74+
private record StepAndClassifier(String stepName, ClassSymbol classifierName) {
75+
}
76+
private record StepAndParameter(String step, String parameter, Tree location) {
77+
}
78+
79+
@Override
80+
public void initialize(Context context) {
81+
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, SklearnPipelineParameterAreCorrectCheck::checkCallExpression);
82+
}
83+
84+
private static void checkCallExpression(SubscriptionContext subscriptionContext) {
85+
CallExpression callExpression = (CallExpression) subscriptionContext.syntaxNode();
86+
87+
var parsedFunctionOptional = Optional.ofNullable(callExpression.calleeSymbol())
88+
.map(Symbol::fullyQualifiedName)
89+
.filter(SKLEARN_SEARCH_FQNS::contains)
90+
.map(callExpr -> getStepAndParametersFromDict(callExpression))
91+
.flatMap(parsedParameters -> getPipelineNameAndParsedParametersFromSearchFunctions(parsedParameters, callExpression))
92+
.or(() -> Optional.ofNullable(callExpression.calleeSymbol())
93+
.map(Symbol::fullyQualifiedName)
94+
.filter("sklearn.pipeline.Pipeline.set_params"::equals)
95+
.map(callExpr -> getStepAndParametersFromArguments(callExpression))
96+
.flatMap(parsedParameters -> getPipelineNameAndParsedParametersFromPipelineSetParamsFunction(parsedParameters, callExpression)));
97+
98+
parsedFunctionOptional.ifPresent(
99+
pipelineNameAndParsedParameters -> {
100+
var parsedFunction = pipelineNameAndParsedParameters.parsedParameters;
101+
var pipelineName = pipelineNameAndParsedParameters.pipelineName;
102+
Expressions.singleAssignedNonNameValue(pipelineName)
103+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(CallExpression.class))
104+
.ifPresent(pipelineCallExpr -> findProblems(parsedFunction, parsePipeline(pipelineCallExpr), subscriptionContext));
105+
});
106+
}
107+
108+
private static Optional<PipelineNameAndParsedParameters> getPipelineNameAndParsedParametersFromPipelineSetParamsFunction(
109+
Map<String, Set<ParameterNameAndLocation>> parsedParameters,
110+
CallExpression callExpression) {
111+
return newPipelineNameAndParsedParameters(
112+
Optional.of(callExpression).map(CallExpression::callee)
113+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(QualifiedExpression.class))
114+
.map(QualifiedExpression::qualifier)
115+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(Name.class))
116+
.orElse(null),
117+
parsedParameters);
118+
}
119+
120+
private static Optional<PipelineNameAndParsedParameters> getPipelineNameAndParsedParametersFromSearchFunctions(Map<String, Set<ParameterNameAndLocation>> parsedParameters,
121+
CallExpression callExpression) {
122+
return newPipelineNameAndParsedParameters(
123+
Optional.ofNullable(TreeUtils.nthArgumentOrKeyword(0, "estimator", callExpression.arguments()))
124+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(RegularArgument.class))
125+
.map(RegularArgument::expression)
126+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(Name.class))
127+
.orElse(null),
128+
parsedParameters);
129+
}
130+
131+
private static Optional<PipelineNameAndParsedParameters> newPipelineNameAndParsedParameters(@Nullable Name pipelineName,
132+
Map<String, Set<ParameterNameAndLocation>> parsedParameters) {
133+
return Optional.ofNullable(pipelineName)
134+
.map(pipelineName1 -> new PipelineNameAndParsedParameters(pipelineName1, parsedParameters));
135+
}
136+
137+
private static void findProblems(Map<String, Set<ParameterNameAndLocation>> setParameters, Map<String, ClassSymbol> pipelineDefinition, SubscriptionContext subscriptionContext) {
138+
for (var entry : setParameters.entrySet()) {
139+
var step = entry.getKey();
140+
var stringAndTree = entry.getValue();
141+
var parameters = stringAndTree.stream().map(ParameterNameAndLocation::string).collect(Collectors.toSet());
142+
143+
var classifier = pipelineDefinition.get(step);
144+
if (classifier == null) {
145+
continue;
146+
}
147+
var possibleParameters = getInitFunctionSymbol(classifier).map(FunctionSymbol::parameters).orElse(List.of());
148+
149+
parameters.forEach(parameter -> {
150+
if (isNotAValidParameter(parameter, possibleParameters)) {
151+
createIssue(subscriptionContext, parameter, stringAndTree);
152+
}
153+
});
154+
}
155+
}
156+
157+
private static void createIssue(SubscriptionContext subscriptionContext, String parameter, Set<ParameterNameAndLocation> parameterNameAndLocation) {
158+
parameterNameAndLocation
159+
.stream()
160+
.filter(parameterNameAndLocation1 -> parameterNameAndLocation1.string()
161+
.equals(parameter))
162+
.findFirst()
163+
.ifPresent(location -> subscriptionContext.addIssue(location.tree, MESSAGE));
164+
}
165+
166+
private static boolean isNotAValidParameter(String parameter, List<FunctionSymbol.Parameter> possibleParameters) {
167+
return possibleParameters.stream().noneMatch(symbol -> Objects.equals(symbol.name(), parameter));
168+
}
169+
170+
private static Optional<FunctionSymbol> getInitFunctionSymbol(ClassSymbol classSymbol) {
171+
return classSymbol.declaredMembers().stream().filter(
172+
memberSymbol -> "__init__".equals(memberSymbol.name())).findFirst().map(FunctionSymbol.class::cast);
173+
}
174+
175+
private static Stream<Expression> getExpressionsFromArgument(@Nullable RegularArgument argument) {
176+
return Optional.ofNullable(argument)
177+
.map(RegularArgument::expression)
178+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(ListLiteralImpl.class))
179+
.map(ListLiteralImpl::elements)
180+
.map(ExpressionList::expressions)
181+
.stream()
182+
.flatMap(Collection::stream);
183+
}
184+
185+
private static Map<String, ClassSymbol> parsePipeline(CallExpression callExpression) {
186+
var stepsArgument = TreeUtils.nthArgumentOrKeyword(0, "steps", callExpression.arguments());
187+
var out = new HashMap<String, ClassSymbol>();
188+
189+
getExpressionsFromArgument(stepsArgument)
190+
.map(
191+
TreeUtils.toInstanceOfMapper(TupleImpl.class))
192+
.filter(Objects::nonNull)
193+
.map(TupleImpl::elements)
194+
.filter(SklearnPipelineParameterAreCorrectCheck::isTwoElementTuple)
195+
// If we find a pipeline inside the pipeline, we need to parse it recursively
196+
.map(SklearnPipelineParameterAreCorrectCheck::createEmptyExpressionAndPrefix)
197+
.flatMap(expandRecursivePipelines())
198+
.forEach(
199+
expressionAndPrefix -> getResult(expressionAndPrefix.tuple())
200+
.ifPresent(stepAndClassifier1 -> out.put(expressionAndPrefix.prefix() + stepAndClassifier1.stepName(), stepAndClassifier1.classifierName())));
201+
return out;
202+
}
203+
204+
private static boolean isTwoElementTuple(List<Expression> elements) {
205+
return elements.size() == 2;
206+
}
207+
208+
private static ExpressionAndPrefix createEmptyExpressionAndPrefix(List<Expression> tuple) {
209+
return new ExpressionAndPrefix(tuple, "", 0);
210+
}
211+
212+
private static Function<ExpressionAndPrefix, Stream<ExpressionAndPrefix>> expandRecursivePipelines() {
213+
return expressionAndPrefix -> {
214+
var tuple = expressionAndPrefix.tuple();
215+
var step = tuple.get(0);
216+
var classifier = tuple.get(1);
217+
218+
if (!step.is(Tree.Kind.STRING_LITERAL) || !classifier.is(Tree.Kind.NAME)) {
219+
return Stream.of(expressionAndPrefix);
220+
}
221+
if (expressionAndPrefix.depth > 10) {
222+
return Stream.of(expressionAndPrefix);
223+
}
224+
225+
return classifierIsANestedPipeline((Name) classifier)
226+
.map(callExpression -> TreeUtils.nthArgumentOrKeyword(0, "steps", callExpression.arguments()))
227+
.map(SklearnPipelineParameterAreCorrectCheck::getExpressionsFromArgument)
228+
.orElse(Stream.empty())
229+
.map(TreeUtils.toInstanceOfMapper(TupleImpl.class))
230+
.filter(Objects::nonNull)
231+
.map(TupleImpl::elements)
232+
.filter(SklearnPipelineParameterAreCorrectCheck::isTwoElementTuple)
233+
.map(elements -> new ExpressionAndPrefix(elements,
234+
expressionAndPrefix.prefix() + ((StringLiteral) step).trimmedQuotesValue() + "__", expressionAndPrefix.depth + 1))
235+
.flatMap(expandRecursivePipelines());
236+
};
237+
}
238+
239+
private static Optional<CallExpression> classifierIsANestedPipeline(Name classifier) {
240+
return Expressions.singleAssignedNonNameValue(classifier)
241+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(CallExpression.class))
242+
.filter(callExpression -> Optional.of(callExpression)
243+
.map(CallExpression::calleeSymbol).map(Symbol::fullyQualifiedName)
244+
.filter("sklearn.pipeline.Pipeline"::equals)
245+
.isPresent());
246+
}
247+
248+
private static Optional<StepAndClassifier> getResult(List<Expression> tuple) {
249+
var step = tuple.get(0);
250+
var classifier = tuple.get(1);
251+
var stepName = Optional.ofNullable(step)
252+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(StringLiteral.class))
253+
.map(StringLiteral::trimmedQuotesValue);
254+
var classifierName = Optional.ofNullable(classifier)
255+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(CallExpression.class))
256+
.map(CallExpression::calleeSymbol)
257+
.filter(symbol -> symbol.is(Symbol.Kind.CLASS) && !"sklearn.pipeline.Pipeline".equals(symbol.fullyQualifiedName()))
258+
.map(ClassSymbol.class::cast);
259+
260+
return stepName.flatMap(stepName1 -> classifierName.map(classifierName1 -> new StepAndClassifier(stepName1, classifierName1)));
261+
}
262+
263+
private static Map<String, Set<ParameterNameAndLocation>> getStepAndParametersFromArguments(CallExpression callExpression) {
264+
return callExpression.arguments()
265+
.stream()
266+
.filter(argument -> argument.is(Tree.Kind.REGULAR_ARGUMENT))
267+
.map(RegularArgument.class::cast)
268+
.map(RegularArgument::keywordArgument)
269+
.filter(Objects::nonNull)
270+
.map(SklearnPipelineParameterAreCorrectCheck::getStepAndParameterFromName)
271+
.<StepAndParameter>mapMulti(Optional::ifPresent)
272+
.collect(mergeStringAndTreeToMapCollector());
273+
}
274+
275+
private static Map<String, Set<ParameterNameAndLocation>> getStepAndParametersFromDict(CallExpression callExpression) {
276+
return Optional.ofNullable(TreeUtils.nthArgumentOrKeyword(1, "param_grid", callExpression.arguments()))
277+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(RegularArgument.class))
278+
.map(RegularArgument::expression)
279+
.stream()
280+
.flatMap(SklearnPipelineParameterAreCorrectCheck::extractKeyValuePairFromDictLiteral)
281+
.map(KeyValuePairImpl::key)
282+
.map(TreeUtils.toInstanceOfMapper(StringLiteral.class))
283+
.filter(Objects::nonNull)
284+
.map(stringLiteral -> getStepAndParameterFromString(stringLiteral.trimmedQuotesValue(), stringLiteral))
285+
.<StepAndParameter>mapMulti(Optional::ifPresent)
286+
.collect(
287+
mergeStringAndTreeToMapCollector());
288+
}
289+
290+
private static Stream<KeyValuePairImpl> extractKeyValuePairFromDictLiteral(Expression expression) {
291+
return Optional.of(expression)
292+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(Name.class))
293+
.flatMap(Expressions::singleAssignedNonNameValue)
294+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(DictionaryLiteralImpl.class))
295+
.map(DictionaryLiteral::elements)
296+
.stream()
297+
.flatMap(Collection::stream)
298+
.map(TreeUtils.toInstanceOfMapper(KeyValuePairImpl.class))
299+
.filter(Objects::nonNull);
300+
}
301+
302+
private static Collector<StepAndParameter, ?, Map<String, Set<ParameterNameAndLocation>>> mergeStringAndTreeToMapCollector() {
303+
return Collectors.toMap(StepAndParameter::step, stepAndParameter -> Set.of(new ParameterNameAndLocation(stepAndParameter.parameter, stepAndParameter.location)),
304+
(set1, set2) -> {
305+
var set = new HashSet<>(set1);
306+
set.addAll(set2);
307+
return set;
308+
});
309+
}
310+
311+
private static Optional<StepAndParameter> getStepAndParameterFromName(Name name) {
312+
return splitStepString(name.name()).map(split -> {
313+
var splitsNotLast = Arrays.stream(split).limit(split.length - 1L).collect(Collectors.joining("__"));
314+
return new StepAndParameter(splitsNotLast, split[split.length - 1], name);
315+
});
316+
}
317+
318+
private static Optional<StepAndParameter> getStepAndParameterFromString(String string, Tree location) {
319+
return splitStepString(string).map(split -> new StepAndParameter(split[0], split[1], location));
320+
}
321+
322+
private static Optional<String[]> splitStepString(String string) {
323+
var split = string.split("__");
324+
if (split.length < 2 || string.endsWith("__")) {
325+
return Optional.empty();
326+
}
327+
return Optional.of(split);
328+
}
329+
330+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
<p>This rule raises an issue when an invalid nested estimator parameter is set on a Pipeline.</p>
2+
<h2>Why is this an issue?</h2>
3+
<p>In the sklearn library, when using the <code>Pipeline</code> class, it is possible to modify the parameters of the nested estimators. This
4+
modification can be done by using the <code>Pipeline</code> method <code>set_params</code> and specifying the name of the estimator and the parameter
5+
to update separated by a double underscore <code>__</code>.</p>
6+
<pre>
7+
from sklearn.pipeline import Pipeline
8+
from sklearn.svm import SVC
9+
10+
pipe = Pipeline(steps=[("clf", SVC())])
11+
pipe.set_params(clf__C=10)
12+
</pre>
13+
<p>In the example above, the regularization parameter <code>C</code> is set to the value <code>10</code> for the classifier called <code>clf</code>.
14+
Setting such parameters can be done as well with the help of the <code>param_grid</code> parameter for example when using
15+
<code>GridSearchCV</code>.</p>
16+
<p>Providing invalid parameters that do not exist on the estimator can lead to unexpected behavior or runtime errors.</p>
17+
<p>This rule checks that the parameters provided to the <code>set_params</code> method of a Pipeline instance or through the <code>param_grid</code>
18+
parameters of a <code>GridSearchCV</code> are valid for the nested estimators.</p>
19+
<h2>How to fix it</h2>
20+
<p>To fix this issue provide valid parameters to the nested estimators.</p>
21+
<h3>Code examples</h3>
22+
<h4>Noncompliant code example</h4>
23+
<pre data-diff-id="1" data-diff-type="noncompliant">
24+
from sklearn.pipeline import Pipeline
25+
from sklearn.decomposition import PCA
26+
27+
pipe = Pipeline(steps=[('reduce_dim', PCA())])
28+
pipe.set_params(reduce_dim__C=2) # Noncompliant: the parameter C does not exists for the PCA estimator
29+
</pre>
30+
<h4>Compliant solution</h4>
31+
<pre data-diff-id="1" data-diff-type="compliant">
32+
from sklearn.pipeline import Pipeline
33+
from sklearn.decomposition import PCA
34+
35+
pipe = Pipeline(steps=[('reduce_dim', PCA())])
36+
pipe.set_params(reduce_dim__n_components=2) # Compliant
37+
</pre>
38+
<h2>Resources</h2>
39+
<h3>Documentation</h3>
40+
<ul>
41+
<li> Scikit-Learn documentation - <a href="https://scikit-learn.org/stable/modules/compose.html#access-to-nested-parameters">Access to nested
42+
parameters</a> </li>
43+
<li> Scikit-Learn documentation - <a
44+
href="https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn-model-selection-gridsearchcv">GridSearchCV
45+
reference</a> </li>
46+
</ul>
47+

0 commit comments

Comments
 (0)