Skip to content

Commit b28933c

Browse files
authored
SONARPY-1775: Important hyperparameters should be specified for Scikit-Learn estimators. (#1788)
1 parent f31a720 commit b28933c

File tree

9 files changed

+555
-17
lines changed

9 files changed

+555
-17
lines changed

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ public static Iterable<Class> getChecks() {
352352
SklearnCachedPipelineDontAccessTransformersCheck.class,
353353
SuperfluousCurlyBraceCheck.class,
354354
SklearnPipelineSpecifyMemoryArgumentCheck.class,
355+
SklearnEstimatorHyperparametersCheck.class,
355356
TempFileCreationCheck.class,
356357
ImplicitlySkippedTestCheck.class,
357358
ToDoCommentCheck.class,
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.List;
23+
import java.util.Map;
24+
import java.util.Objects;
25+
import java.util.Optional;
26+
import java.util.Set;
27+
import org.sonar.check.Rule;
28+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
29+
import org.sonar.plugins.python.api.SubscriptionContext;
30+
import org.sonar.plugins.python.api.symbols.ClassSymbol;
31+
import org.sonar.plugins.python.api.symbols.Symbol;
32+
import org.sonar.plugins.python.api.symbols.Symbol.Kind;
33+
import org.sonar.plugins.python.api.symbols.Usage;
34+
import org.sonar.plugins.python.api.tree.CallExpression;
35+
import org.sonar.plugins.python.api.tree.Name;
36+
import org.sonar.plugins.python.api.tree.QualifiedExpression;
37+
import org.sonar.plugins.python.api.tree.RegularArgument;
38+
import org.sonar.plugins.python.api.tree.Tree;
39+
import org.sonar.python.checks.utils.Expressions;
40+
import org.sonar.python.tree.TreeUtils;
41+
42+
import static org.sonar.plugins.python.api.tree.Tree.Kind.CALL_EXPR;
43+
import static org.sonar.plugins.python.api.tree.Tree.Kind.REGULAR_ARGUMENT;
44+
import static org.sonar.python.tree.TreeUtils.toOptionalInstanceOfMapper;
45+
46+
@Rule(key = "S6973")
47+
public class SklearnEstimatorHyperparametersCheck extends PythonSubscriptionCheck {
48+
49+
private static final String MESSAGE = "Specify important hyperparameters when instantiating a Scikit-learn estimator.";
50+
51+
private record Param(String name, Optional<Integer> position) {
52+
public Param(String name) {
53+
this(name, Optional.empty());
54+
}
55+
56+
public Param(String name, int position) {
57+
this(name, Optional.of(position));
58+
}
59+
}
60+
61+
private static final String LEARNING_RATE = "learning_rate";
62+
private static final String N_NEIGHBORS = "n_neighbors";
63+
private static final String KERNEL = "kernel";
64+
private static final String GAMMA = "gamma";
65+
private static final String C = "C";
66+
67+
private static final Map<String, List<Param>> ESTIMATORS_AND_PARAMETERS_TO_CHECK = Map.ofEntries(
68+
Map.entry("sklearn.ensemble._weight_boosting.AdaBoostClassifier", List.of(new Param(LEARNING_RATE))),
69+
Map.entry("sklearn.ensemble._weight_boosting.AdaBoostRegressor", List.of(new Param(LEARNING_RATE))),
70+
Map.entry("sklearn.ensemble._gb.GradientBoostingClassifier", List.of(new Param(LEARNING_RATE))),
71+
Map.entry("sklearn.ensemble._gb.GradientBoostingRegressor", List.of(new Param(LEARNING_RATE))),
72+
Map.entry("sklearn.ensemble._hist_gradient_boosting.gradient_boosting.HistGradientBoostingClassifier", List.of(new Param(LEARNING_RATE))),
73+
Map.entry("sklearn.ensemble._hist_gradient_boosting.gradient_boosting.HistGradientBoostingRegressor", List.of(new Param(LEARNING_RATE))),
74+
Map.entry("sklearn.ensemble._forest.RandomForestClassifier", List.of(new Param("min_samples_leaf"), new Param("max_features"))),
75+
Map.entry("sklearn.ensemble._forest.RandomForestRegressor", List.of(new Param("min_samples_leaf"), new Param("max_features"))),
76+
Map.entry("sklearn.linear_model._coordinate_descent.ElasticNet", List.of(new Param("alpha", 0), new Param("l1_ratio"))),
77+
Map.entry("sklearn.neighbors._unsupervised.NearestNeighbors", List.of(new Param(N_NEIGHBORS, 0))),
78+
Map.entry("sklearn.neighbors._classification.KNeighborsClassifier", List.of(new Param(N_NEIGHBORS, 0))),
79+
Map.entry("sklearn.neighbors._regression.KNeighborsRegressor", List.of(new Param(N_NEIGHBORS, 0))),
80+
Map.entry("sklearn.svm._classes.NuSVC", List.of(new Param("nu"), new Param(KERNEL), new Param(GAMMA))),
81+
Map.entry("sklearn.svm._classes.NuSVR", List.of(new Param(C), new Param(KERNEL), new Param(GAMMA))),
82+
Map.entry("sklearn.svm._classes.SVC", List.of(new Param(C), new Param(KERNEL), new Param(GAMMA))),
83+
Map.entry("sklearn.svm._classes.SVR", List.of(new Param(C), new Param(KERNEL), new Param(GAMMA))),
84+
Map.entry("sklearn.tree._classes.DecisionTreeClassifier", List.of(new Param("ccp_alpha"))),
85+
Map.entry("sklearn.tree._classes.DecisionTreeRegressor", List.of(new Param("ccp_alpha"))),
86+
Map.entry("sklearn.neural_network._multilayer_perceptron.MLPClassifier", List.of(new Param("hidden_layer_sizes", 0))),
87+
Map.entry("sklearn.neural_network._multilayer_perceptron.MLPRegressor", List.of(new Param("hidden_layer_sizes", 0))),
88+
Map.entry("sklearn.preprocessing._polynomial.PolynomialFeatures", List.of(new Param("degree", 0), new Param("interaction_only"))));
89+
90+
private static final Set<String> SEARCH_CV_FQNS = Set.of(
91+
"sklearn.model_selection._search.GridSearchCV",
92+
"sklearn.model_selection._search.RandomizedSearchCV",
93+
"sklearn.model_selection._search_successive_halving.HalvingRandomSearchCV",
94+
"sklearn.model_selection._search_successive_halving.HalvingGridSearchCV");
95+
96+
private static final Set<String> PIPELINE_FQNS = Set.of(
97+
"sklearn.pipeline.make_pipeline",
98+
"sklearn.pipeline.Pipeline");
99+
100+
@Override
101+
public void initialize(Context context) {
102+
context.registerSyntaxNodeConsumer(CALL_EXPR, SklearnEstimatorHyperparametersCheck::checkEstimator);
103+
}
104+
105+
private static void checkEstimator(SubscriptionContext ctx) {
106+
CallExpression callExpression = (CallExpression) ctx.syntaxNode();
107+
108+
Symbol calleeSymbol = callExpression.calleeSymbol();
109+
110+
Optional.ofNullable(calleeSymbol)
111+
.filter(callee -> callee.is(Kind.CLASS))
112+
.map(ClassSymbol.class::cast)
113+
.map(ClassSymbol::fullyQualifiedName)
114+
.map(ESTIMATORS_AND_PARAMETERS_TO_CHECK::get)
115+
.filter(parameters -> !isDirectlyUsedInSearchCV(callExpression))
116+
.filter(parameters -> !isSetParamsCalled(callExpression))
117+
.filter(parameters -> !isPartOfPipelineAndSearchCV(callExpression))
118+
.filter(parameters -> isMissingAHyperparameter(callExpression, parameters))
119+
.ifPresent(parameters -> ctx.addIssue(callExpression, MESSAGE));
120+
}
121+
122+
private static boolean isMissingAHyperparameter(CallExpression callExpression, List<Param> parametersToCheck) {
123+
return parametersToCheck.stream()
124+
.map(param -> param.position()
125+
.map(position -> TreeUtils.nthArgumentOrKeyword(position, param.name, callExpression.arguments()))
126+
.orElse(TreeUtils.argumentByKeyword(param.name, callExpression.arguments())))
127+
.anyMatch(Objects::isNull);
128+
}
129+
130+
private static boolean isDirectlyUsedInSearchCV(CallExpression callExpression) {
131+
return Optional.ofNullable(TreeUtils.firstAncestorOfKind(callExpression, REGULAR_ARGUMENT))
132+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(RegularArgument.class))
133+
.map(SklearnEstimatorHyperparametersCheck::isArgumentPartOfSearchCV)
134+
.orElse(false);
135+
}
136+
137+
private static boolean isPartOfPipelineAndSearchCV(CallExpression callExpression) {
138+
return Expressions.getAssignedName(callExpression)
139+
.map(SklearnEstimatorHyperparametersCheck::isEstimatorUsedInSearchCV)
140+
.or(() -> getPipelineAssignement(callExpression)
141+
.map(SklearnEstimatorHyperparametersCheck::isEstimatorUsedInSearchCV))
142+
.orElse(false);
143+
}
144+
145+
private static Optional<Name> getPipelineAssignement(CallExpression callExpression) {
146+
return Optional.ofNullable(TreeUtils.firstAncestorOfKind(callExpression, CALL_EXPR))
147+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(CallExpression.class))
148+
.filter(callExp -> Optional.ofNullable(callExp.calleeSymbol())
149+
.map(Symbol::fullyQualifiedName)
150+
.map(PIPELINE_FQNS::contains)
151+
.orElse(false))
152+
.flatMap(Expressions::getAssignedName);
153+
}
154+
155+
private static boolean isEstimatorUsedInSearchCV(Name estimator) {
156+
return Optional.ofNullable(estimator.symbol())
157+
.map(Symbol::usages)
158+
.map(usages -> usages.stream()
159+
.map(Usage::tree)
160+
.map(Tree::parent)
161+
.filter(parent -> parent.is(REGULAR_ARGUMENT))
162+
.map(RegularArgument.class::cast)
163+
.anyMatch(SklearnEstimatorHyperparametersCheck::isArgumentPartOfSearchCV))
164+
.orElse(false);
165+
}
166+
167+
private static boolean isArgumentPartOfSearchCV(RegularArgument arg) {
168+
return Optional.ofNullable(TreeUtils.firstAncestorOfKind(arg, CALL_EXPR))
169+
.flatMap(toOptionalInstanceOfMapper(CallExpression.class))
170+
.map(CallExpression::calleeSymbol)
171+
.map(Symbol::fullyQualifiedName)
172+
.map(SEARCH_CV_FQNS::contains)
173+
.orElse(false);
174+
}
175+
176+
private static boolean isSetParamsCalled(CallExpression callExpression) {
177+
return Expressions.getAssignedName(callExpression)
178+
.map(Name::symbol)
179+
.map(Symbol::usages)
180+
.map(SklearnEstimatorHyperparametersCheck::isUsedWithSetParams)
181+
.orElse(false);
182+
}
183+
184+
private static boolean isUsedWithSetParams(List<Usage> usages) {
185+
return usages.stream()
186+
.map(Usage::tree)
187+
.map(Tree::parent)
188+
.filter(parent -> parent.is(Tree.Kind.QUALIFIED_EXPR))
189+
.map(TreeUtils.toInstanceOfMapper(QualifiedExpression.class))
190+
.filter(Objects::nonNull)
191+
.map(qExp -> qExp.name().name())
192+
.anyMatch("set_params"::equals);
193+
}
194+
195+
}

python-checks/src/main/java/org/sonar/python/checks/utils/Expressions.java

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@
4141
import org.sonar.plugins.python.api.tree.NumericLiteral;
4242
import org.sonar.plugins.python.api.tree.ParenthesizedExpression;
4343
import org.sonar.plugins.python.api.tree.QualifiedExpression;
44+
import org.sonar.plugins.python.api.tree.SliceExpression;
4445
import org.sonar.plugins.python.api.tree.StringElement;
4546
import org.sonar.plugins.python.api.tree.StringLiteral;
47+
import org.sonar.plugins.python.api.tree.SubscriptionExpression;
4648
import org.sonar.plugins.python.api.tree.Tree;
4749
import org.sonar.plugins.python.api.tree.Tree.Kind;
4850
import org.sonar.plugins.python.api.tree.Tuple;
@@ -309,30 +311,44 @@ private static EscapeSequence extractOctal(String value, int i) {
309311
}
310312

311313
public static Optional<Name> getAssignedName(Expression expression) {
314+
return getAssignedName(expression, 0);
315+
}
316+
317+
private static Optional<Name> getAssignedName(Expression expression, int recursionCount) {
318+
if(recursionCount > 4){
319+
return Optional.empty();
320+
}
312321
if (expression.is(Tree.Kind.NAME)) {
313322
return Optional.of((Name) expression);
314323
}
315324
if (expression.is(Tree.Kind.QUALIFIED_EXPR)) {
316-
return getAssignedName(((QualifiedExpression) expression).name());
325+
return Optional.of(((QualifiedExpression) expression).name());
326+
}
327+
if(expression.is(Tree.Kind.SUBSCRIPTION)){
328+
expression = ((SubscriptionExpression) expression).object();
329+
}
330+
if(expression.is(Tree.Kind.SLICE_EXPR)){
331+
expression = ((SliceExpression) expression).object();
317332
}
318333

319-
var assignment = (AssignmentStatement) TreeUtils.firstAncestorOfKind(expression, Tree.Kind.ASSIGNMENT_STMT);
320-
if (assignment == null) {
334+
var maybeAssignment = TreeUtils.firstAncestorOfKind(expression, Tree.Kind.ASSIGNMENT_STMT);
335+
if (maybeAssignment == null) {
321336
return Optional.empty();
322337
}
323-
338+
339+
var assignment = (AssignmentStatement) maybeAssignment;
324340
var expressions = SymbolUtils.assignmentsLhs(assignment);
325-
if (expressions.size() != 1) {
341+
342+
if (expressions.size() != 1 ) {
326343
List<Expression> rhsExpressions = getExpressionsFromRhs(assignment.assignedValue());
327344
var rhsIndex = rhsExpressions.stream().flatMap(TreeUtils::flattenTuples).toList().indexOf(expression);
328-
if (rhsIndex != -1) {
329-
return getAssignedName(expressions.get(rhsIndex));
345+
if (rhsIndex != -1 && rhsIndex < expressions.size()) {
346+
return getAssignedName(expressions.get(rhsIndex), recursionCount + 1);
330347
} else {
331348
return Optional.empty();
332349
}
333350
}
334-
335-
return getAssignedName(expressions.get(0));
351+
return getAssignedName(expressions.get(0), recursionCount + 1);
336352
}
337353

338354
private static List<Expression> getExpressionsFromRhs(Expression rhs) {

0 commit comments

Comments
 (0)