Skip to content

Commit 7d62937

Browse files
authored
SONARPY-1900: Extend S6973 - Important hyperparameters should be specified for Scikit-Learn estimators (#1959)
* SONARPY-1900 add pytorch to S6973 * SONARPY-1900 reorder methods to improve readability * SONARPY-1900 add expected ruling issues * SONARPY-1900 add positional data to pytorch optimisers * SONARPY-1900 fix MissingHyperParameterCheck in accordance with PR * SONARPY-1900 fix tests in accordance with PR * SONARPY-1900 use Expressions.containsSpreadOperator(...)
1 parent 0068790 commit 7d62937

File tree

7 files changed

+419
-201
lines changed

7 files changed

+419
-201
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
"project:pecos/examples/giant-xrt/OGB_baselines/ogbn-arxiv/gnn.py": [
3+
165
4+
],
5+
"project:pecos/examples/giant-xrt/OGB_baselines/ogbn-arxiv/mlp.py": [
6+
124
7+
],
8+
"project:pecos/examples/giant-xrt/OGB_baselines/ogbn-papers100M/mlp_sgc.py": [
9+
130
10+
],
11+
"project:pecos/examples/giant-xrt/OGB_baselines/ogbn-papers100M/mlp_xrt.py": [
12+
139
13+
],
14+
"project:pecos/examples/giant-xrt/OGB_baselines/ogbn-products/graph_saint.py": [
15+
184
16+
],
17+
"project:pecos/examples/giant-xrt/OGB_baselines/ogbn-products/mlp.py": [
18+
124
19+
]
20+
}

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ public static Iterable<Class> getChecks() {
355355
StringReplaceCheck.class,
356356
StrongCryptographicKeysCheck.class,
357357
SklearnCachedPipelineDontAccessTransformersCheck.class,
358-
SklearnEstimatorHyperparametersCheck.class,
358+
MissingHyperParameterCheck.class,
359359
SklearnPipelineSpecifyMemoryArgumentCheck.class,
360360
SklearnPipelineParameterAreCorrectCheck.class,
361361
SuperfluousCurlyBraceCheck.class,
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.List;
23+
import java.util.Map;
24+
import java.util.Objects;
25+
import java.util.Optional;
26+
import java.util.Set;
27+
import java.util.stream.Collectors;
28+
import org.sonar.check.Rule;
29+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
30+
import org.sonar.plugins.python.api.SubscriptionContext;
31+
import org.sonar.plugins.python.api.symbols.Symbol;
32+
import org.sonar.plugins.python.api.symbols.Usage;
33+
import org.sonar.plugins.python.api.tree.CallExpression;
34+
import org.sonar.plugins.python.api.tree.Name;
35+
import org.sonar.plugins.python.api.tree.QualifiedExpression;
36+
import org.sonar.plugins.python.api.tree.RegularArgument;
37+
import org.sonar.plugins.python.api.tree.Tree;
38+
import org.sonar.python.checks.utils.Expressions;
39+
import org.sonar.python.tree.TreeUtils;
40+
41+
import static org.sonar.plugins.python.api.tree.Tree.Kind.CALL_EXPR;
42+
import static org.sonar.plugins.python.api.tree.Tree.Kind.REGULAR_ARGUMENT;
43+
import static org.sonar.python.tree.TreeUtils.toOptionalInstanceOfMapper;
44+
45+
@Rule(key = "S6973")
46+
public class MissingHyperParameterCheck extends PythonSubscriptionCheck {
47+
private static final String SKLEARN_MESSAGE = "Add the missing hyperparameter%s %s for this Scikit-learn estimator.";
48+
private static final String PYTORCH_MESSAGE = "Add the missing hyperparameter%s %s for this PyTorch optimizer.";
49+
50+
private record Param(String name, Optional<Integer> position) {
51+
public Param(String name) {
52+
this(name, Optional.empty());
53+
}
54+
55+
public Param(String name, int position) {
56+
this(name, Optional.of(position));
57+
}
58+
}
59+
60+
@Override
61+
public void initialize(Context context) {
62+
context.registerSyntaxNodeConsumer(CALL_EXPR, MissingHyperParameterCheck::checkEstimator);
63+
}
64+
65+
private static void checkEstimator(SubscriptionContext ctx) {
66+
CallExpression callExpression = (CallExpression) ctx.syntaxNode();
67+
Symbol calleeSymbol = callExpression.calleeSymbol();
68+
69+
Optional.ofNullable(calleeSymbol)
70+
.map(Symbol::fullyQualifiedName).ifPresent(name -> {
71+
checkPyTorchOptimizer(name, callExpression, ctx);
72+
checkSkLearnEstimator(name, callExpression, ctx);
73+
});
74+
}
75+
76+
private static void checkPyTorchOptimizer(String name, CallExpression callExpression, SubscriptionContext ctx) {
77+
PyTorchCheck.getMissingParameters(name, callExpression)
78+
.map(MissingHyperParameterCheck::toParameterNames)
79+
.ifPresent(parameters -> ctx.addIssue(callExpression, formatMessage(parameters, PYTORCH_MESSAGE)));
80+
}
81+
82+
private static void checkSkLearnEstimator(String name, CallExpression callExpression, SubscriptionContext ctx) {
83+
SkLearnCheck.getMissingParameters(name, callExpression)
84+
.map(MissingHyperParameterCheck::toParameterNames)
85+
.ifPresent(parameters -> ctx.addIssue(callExpression, formatMessage(parameters, SKLEARN_MESSAGE)));
86+
}
87+
88+
private static List<String> toParameterNames(List<Param> parameters) {
89+
return parameters.stream().map(Param::name).toList();
90+
}
91+
92+
private static String formatMessage(List<String> missingArgs, String formatString) {
93+
String plural = missingArgs.size() == 1 ? "" : "s";
94+
String missingArgsString = missingArgs.get(missingArgs.size() - 1);
95+
if (missingArgs.size() > 1) {
96+
missingArgsString = missingArgs.subList(0, missingArgs.size() - 1).stream().collect(Collectors.joining(", " +
97+
"")) + " and " + missingArgsString;
98+
}
99+
return formatString.formatted(plural, missingArgsString);
100+
}
101+
102+
103+
// common method used by both the PyTorchCheck class and SkLearnCheck class
104+
private static boolean isMissingAHyperparameter(CallExpression callExpression, List<Param> parametersToCheck) {
105+
return parametersToCheck.stream()
106+
.map(param -> param.position()
107+
.map(position -> TreeUtils.nthArgumentOrKeyword(position, param.name, callExpression.arguments()))
108+
.orElse(TreeUtils.argumentByKeyword(param.name, callExpression.arguments())))
109+
.anyMatch(Objects::isNull);
110+
}
111+
112+
private static class PyTorchCheck {
113+
public static final String LR = "lr";
114+
public static final String WEIGHT_DECAY = "weight_decay";
115+
116+
private static final Map<String, List<Param>> PY_TORCH_ESTIMATORS_AND_PARAMETERS_TO_CHECK = Map.ofEntries(
117+
Map.entry("torch.utils.data.DataLoader", List.of(new Param("batch_size", 1))),
118+
Map.entry("torch.optim.Adadelta", List.of(new Param(LR, 1), new Param(WEIGHT_DECAY, 4))),
119+
Map.entry("torch.optim.Adagrad", List.of(new Param(LR, 1), new Param(WEIGHT_DECAY, 3))),
120+
Map.entry("torch.optim.Adam", List.of(new Param(LR, 1), new Param(WEIGHT_DECAY, 4))),
121+
Map.entry("torch.optim.AdamW", List.of(new Param(LR, 1), new Param(WEIGHT_DECAY, 4))),
122+
Map.entry("torch.optim.SparseAdam", List.of(new Param(LR, 1))),
123+
Map.entry("torch.optim.Adamax", List.of(new Param(LR, 1), new Param(WEIGHT_DECAY, 4))),
124+
Map.entry("torch.optim.ASGD", List.of(new Param(LR, 1), new Param(WEIGHT_DECAY, 5))),
125+
Map.entry("torch.optim.LBFGS", List.of(new Param(LR, 1))),
126+
Map.entry("torch.optim.NAdam", List.of(new Param(LR, 1), new Param(WEIGHT_DECAY, 4), new Param("momentum_decay", 5))),
127+
Map.entry("torch.optim.RAdam", List.of(new Param(LR, 1), new Param(WEIGHT_DECAY, 4))),
128+
Map.entry("torch.optim.RMSprop", List.of(new Param(LR, 1), new Param(WEIGHT_DECAY, 4), new Param("momentum", 5))),
129+
Map.entry("torch.optim.Rprop", List.of(new Param(LR, 1))),
130+
Map.entry("torch.optim.SGD", List.of(new Param(LR, 1), new Param("momentum", 2), new Param(WEIGHT_DECAY, 4)))
131+
);
132+
133+
public static Optional<List<Param>> getMissingParameters(String name, CallExpression callExpression) {
134+
return Optional.ofNullable(PY_TORCH_ESTIMATORS_AND_PARAMETERS_TO_CHECK.get(name))
135+
.filter(parameters -> !Expressions.containsSpreadOperator(callExpression.arguments()))
136+
.filter(parameters -> isMissingAHyperparameter(callExpression, parameters));
137+
}
138+
}
139+
140+
private static class SkLearnCheck {
141+
private static final String LEARNING_RATE = "learning_rate";
142+
private static final String N_NEIGHBORS = "n_neighbors";
143+
private static final String KERNEL = "kernel";
144+
private static final String GAMMA = "gamma";
145+
private static final String C = "C";
146+
147+
private static final Map<String, List<Param>> SK_LEARN_ESTIMATORS_AND_PARAMETERS_TO_CHECK = Map.ofEntries(
148+
Map.entry("sklearn.ensemble._weight_boosting.AdaBoostClassifier", List.of(new Param(LEARNING_RATE))),
149+
Map.entry("sklearn.ensemble._weight_boosting.AdaBoostRegressor", List.of(new Param(LEARNING_RATE))),
150+
Map.entry("sklearn.ensemble._gb.GradientBoostingClassifier", List.of(new Param(LEARNING_RATE))),
151+
Map.entry("sklearn.ensemble._gb.GradientBoostingRegressor", List.of(new Param(LEARNING_RATE))),
152+
Map.entry("sklearn.ensemble._hist_gradient_boosting.gradient_boosting.HistGradientBoostingClassifier",
153+
List.of(new Param(LEARNING_RATE))),
154+
Map.entry("sklearn.ensemble._hist_gradient_boosting.gradient_boosting.HistGradientBoostingRegressor",
155+
List.of(new Param(LEARNING_RATE))),
156+
Map.entry("sklearn.ensemble._forest.RandomForestClassifier", List.of(new Param("min_samples_leaf"), new Param("max_features"))),
157+
Map.entry("sklearn.ensemble._forest.RandomForestRegressor", List.of(new Param("min_samples_leaf"), new Param("max_features"))),
158+
Map.entry("sklearn.linear_model._coordinate_descent.ElasticNet", List.of(new Param("alpha", 0), new Param("l1_ratio"))),
159+
Map.entry("sklearn.neighbors._unsupervised.NearestNeighbors", List.of(new Param(N_NEIGHBORS, 0))),
160+
Map.entry("sklearn.neighbors._classification.KNeighborsClassifier", List.of(new Param(N_NEIGHBORS, 0))),
161+
Map.entry("sklearn.neighbors._regression.KNeighborsRegressor", List.of(new Param(N_NEIGHBORS, 0))),
162+
Map.entry("sklearn.svm._classes.NuSVC", List.of(new Param("nu"), new Param(KERNEL), new Param(GAMMA))),
163+
Map.entry("sklearn.svm._classes.NuSVR", List.of(new Param(C), new Param(KERNEL), new Param(GAMMA))),
164+
Map.entry("sklearn.svm._classes.SVC", List.of(new Param(C), new Param(KERNEL), new Param(GAMMA))),
165+
Map.entry("sklearn.svm._classes.SVR", List.of(new Param(C), new Param(KERNEL), new Param(GAMMA))),
166+
Map.entry("sklearn.tree._classes.DecisionTreeClassifier", List.of(new Param("ccp_alpha"))),
167+
Map.entry("sklearn.tree._classes.DecisionTreeRegressor", List.of(new Param("ccp_alpha"))),
168+
Map.entry("sklearn.neural_network._multilayer_perceptron.MLPClassifier", List.of(new Param("hidden_layer_sizes", 0))),
169+
Map.entry("sklearn.neural_network._multilayer_perceptron.MLPRegressor", List.of(new Param("hidden_layer_sizes", 0))),
170+
Map.entry("sklearn.preprocessing._polynomial.PolynomialFeatures", List.of(new Param("degree", 0), new Param("interaction_only"))));
171+
172+
private static final Set<String> SEARCH_CV_FQNS = Set.of(
173+
"sklearn.model_selection._search.GridSearchCV",
174+
"sklearn.model_selection._search.RandomizedSearchCV",
175+
"sklearn.model_selection._search_successive_halving.HalvingRandomSearchCV",
176+
"sklearn.model_selection._search_successive_halving.HalvingGridSearchCV");
177+
178+
private static final Set<String> PIPELINE_FQNS = Set.of(
179+
"sklearn.pipeline.make_pipeline",
180+
"sklearn.pipeline.Pipeline");
181+
182+
public static Optional<List<Param>> getMissingParameters(String name, CallExpression callExpression) {
183+
return Optional.ofNullable(SK_LEARN_ESTIMATORS_AND_PARAMETERS_TO_CHECK.get(name))
184+
.filter(parameters -> !isDirectlyUsedInSearchCV(callExpression))
185+
.filter(parameters -> !isSetParamsCalled(callExpression))
186+
.filter(parameters -> !isPartOfPipelineAndSearchCV(callExpression))
187+
.filter(parameters -> isMissingAHyperparameter(callExpression, parameters));
188+
}
189+
190+
private static boolean isDirectlyUsedInSearchCV(CallExpression callExpression) {
191+
return Optional.ofNullable(TreeUtils.firstAncestorOfKind(callExpression, REGULAR_ARGUMENT))
192+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(RegularArgument.class))
193+
.map(SkLearnCheck::isArgumentPartOfSearchCV)
194+
.orElse(false);
195+
}
196+
197+
private static boolean isSetParamsCalled(CallExpression callExpression) {
198+
return Expressions.getAssignedName(callExpression)
199+
.map(Name::symbol)
200+
.map(Symbol::usages)
201+
.map(SkLearnCheck::isUsedWithSetParams)
202+
.orElse(false);
203+
}
204+
205+
private static boolean isUsedWithSetParams(List<Usage> usages) {
206+
return usages.stream()
207+
.map(Usage::tree)
208+
.map(Tree::parent)
209+
.filter(parent -> parent.is(Tree.Kind.QUALIFIED_EXPR))
210+
.map(TreeUtils.toInstanceOfMapper(QualifiedExpression.class))
211+
.filter(Objects::nonNull)
212+
.map(qExp -> qExp.name().name())
213+
.anyMatch("set_params"::equals);
214+
}
215+
216+
private static boolean isPartOfPipelineAndSearchCV(CallExpression callExpression) {
217+
return Expressions.getAssignedName(callExpression)
218+
.map(SkLearnCheck::isEstimatorUsedInSearchCV)
219+
.or(() -> getPipelineAssignement(callExpression)
220+
.map(SkLearnCheck::isEstimatorUsedInSearchCV))
221+
.orElse(false);
222+
}
223+
224+
private static boolean isEstimatorUsedInSearchCV(Name estimator) {
225+
return Optional.ofNullable(estimator.symbol())
226+
.map(Symbol::usages)
227+
.map(usages -> usages.stream()
228+
.map(Usage::tree)
229+
.map(Tree::parent)
230+
.filter(parent -> parent.is(REGULAR_ARGUMENT))
231+
.map(RegularArgument.class::cast)
232+
.anyMatch(SkLearnCheck::isArgumentPartOfSearchCV))
233+
.orElse(false);
234+
}
235+
236+
private static boolean isArgumentPartOfSearchCV(RegularArgument arg) {
237+
return Optional.ofNullable(TreeUtils.firstAncestorOfKind(arg, CALL_EXPR))
238+
.flatMap(toOptionalInstanceOfMapper(CallExpression.class))
239+
.map(CallExpression::calleeSymbol)
240+
.map(Symbol::fullyQualifiedName)
241+
.map(SEARCH_CV_FQNS::contains)
242+
.orElse(false);
243+
}
244+
245+
private static Optional<Name> getPipelineAssignement(CallExpression callExpression) {
246+
return Optional.ofNullable(TreeUtils.firstAncestorOfKind(callExpression, CALL_EXPR))
247+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(CallExpression.class))
248+
.filter(callExp -> Optional.ofNullable(callExp.calleeSymbol())
249+
.map(Symbol::fullyQualifiedName)
250+
.map(PIPELINE_FQNS::contains)
251+
.orElse(false))
252+
.flatMap(Expressions::getAssignedName);
253+
}
254+
}
255+
}

0 commit comments

Comments
 (0)