Skip to content

Commit ac812bc

Browse files
authored
SONARPY-2147 fix rule S6973 (#1981)
* SONARPY-2147 fix message bug
* SONARPY-2147 fix nested SearchCV bug
1 parent 4862874 commit ac812bc

File tree

3 files changed

+37
-42
lines changed

3 files changed

+37
-42
lines changed

python-checks/src/main/java/org/sonar/python/checks/MissingHyperParameterCheck.java

Lines changed: 30 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.sonar.python.checks;
2121

22+
import java.util.Collections;
2223
import java.util.List;
2324
import java.util.Map;
2425
import java.util.Objects;
@@ -74,19 +75,20 @@ private static void checkEstimator(SubscriptionContext ctx) {
7475
}
7576

7677
private static void checkPyTorchOptimizer(String name, CallExpression callExpression, SubscriptionContext ctx) {
77-
PyTorchCheck.getMissingParameters(name, callExpression)
78-
.map(MissingHyperParameterCheck::toParameterNames)
79-
.ifPresent(parameters -> ctx.addIssue(callExpression, formatMessage(parameters, PYTORCH_MESSAGE)));
80-
}
78+
List<String> missingParams = PyTorchCheck.getMissingParameters(name, callExpression).stream()
79+
.map(Param::name)
80+
.toList();
8181

82-
private static void checkSkLearnEstimator(String name, CallExpression callExpression, SubscriptionContext ctx) {
83-
SkLearnCheck.getMissingParameters(name, callExpression)
84-
.map(MissingHyperParameterCheck::toParameterNames)
85-
.ifPresent(parameters -> ctx.addIssue(callExpression, formatMessage(parameters, SKLEARN_MESSAGE)));
82+
if (!missingParams.isEmpty()) {
83+
ctx.addIssue(callExpression, formatMessage(missingParams, PYTORCH_MESSAGE));
84+
}
8685
}
8786

88-
private static List<String> toParameterNames(List<Param> parameters) {
89-
return parameters.stream().map(Param::name).toList();
87+
private static void checkSkLearnEstimator(String name, CallExpression callExpression, SubscriptionContext ctx) {
88+
List<String> missingParams = SkLearnCheck.getMissingParameters(name, callExpression).stream().map(Param::name).toList();
89+
if(!missingParams.isEmpty()) {
90+
ctx.addIssue(callExpression, formatMessage(missingParams, SKLEARN_MESSAGE));
91+
}
9092
}
9193

9294
private static String formatMessage(List<String> missingArgs, String formatString) {
@@ -101,12 +103,12 @@ private static String formatMessage(List<String> missingArgs, String formatStrin
101103

102104

103105
// Common helper shared by the PyTorchCheck and SkLearnCheck classes to detect hyperparameters that are missing from a call
104-
private static boolean isMissingAHyperparameter(CallExpression callExpression, List<Param> parametersToCheck) {
106+
private static List<Param> filterUsedHyperparameter(CallExpression callExpression, List<Param> parametersToCheck) {
105107
return parametersToCheck.stream()
106-
.map(param -> param.position()
108+
.filter(param -> param.position()
107109
.map(position -> TreeUtils.nthArgumentOrKeyword(position, param.name, callExpression.arguments()))
108-
.orElse(TreeUtils.argumentByKeyword(param.name, callExpression.arguments())))
109-
.anyMatch(Objects::isNull);
110+
.orElse(TreeUtils.argumentByKeyword(param.name, callExpression.arguments())) == null)
111+
.toList();
110112
}
111113

112114
private static class PyTorchCheck {
@@ -130,10 +132,11 @@ private static class PyTorchCheck {
130132
Map.entry("torch.optim.SGD", List.of(new Param(LR, 1), new Param("momentum", 2), new Param(WEIGHT_DECAY, 4)))
131133
);
132134

133-
public static Optional<List<Param>> getMissingParameters(String name, CallExpression callExpression) {
135+
public static List<Param> getMissingParameters(String name, CallExpression callExpression) {
134136
return Optional.ofNullable(PY_TORCH_ESTIMATORS_AND_PARAMETERS_TO_CHECK.get(name))
135137
.filter(parameters -> !Expressions.containsSpreadOperator(callExpression.arguments()))
136-
.filter(parameters -> isMissingAHyperparameter(callExpression, parameters));
138+
.map(parameters -> filterUsedHyperparameter(callExpression, parameters))
139+
.orElse(Collections.emptyList());
137140
}
138141
}
139142

@@ -175,23 +178,24 @@ private static class SkLearnCheck {
175178
"sklearn.model_selection._search_successive_halving.HalvingRandomSearchCV",
176179
"sklearn.model_selection._search_successive_halving.HalvingGridSearchCV");
177180

178-
private static final Set<String> PIPELINE_FQNS = Set.of(
179-
"sklearn.pipeline.make_pipeline",
180-
"sklearn.pipeline.Pipeline");
181-
182-
public static Optional<List<Param>> getMissingParameters(String name, CallExpression callExpression) {
181+
public static List<Param> getMissingParameters(String name, CallExpression callExpression) {
183182
return Optional.ofNullable(SK_LEARN_ESTIMATORS_AND_PARAMETERS_TO_CHECK.get(name))
184183
.filter(parameters -> !isDirectlyUsedInSearchCV(callExpression))
185184
.filter(parameters -> !isSetParamsCalled(callExpression))
186185
.filter(parameters -> !isPartOfPipelineAndSearchCV(callExpression))
187-
.filter(parameters -> isMissingAHyperparameter(callExpression, parameters));
186+
.map(parameters -> filterUsedHyperparameter(callExpression, parameters))
187+
.orElse(Collections.emptyList());
188188
}
189189

190190
private static boolean isDirectlyUsedInSearchCV(CallExpression callExpression) {
191-
return Optional.ofNullable(TreeUtils.firstAncestorOfKind(callExpression, REGULAR_ARGUMENT))
192-
.flatMap(TreeUtils.toOptionalInstanceOfMapper(RegularArgument.class))
193-
.map(SkLearnCheck::isArgumentPartOfSearchCV)
194-
.orElse(false);
191+
Tree current = callExpression;
192+
do{
193+
current = TreeUtils.firstAncestorOfKind(current, REGULAR_ARGUMENT);
194+
if(current instanceof RegularArgument arg && isArgumentPartOfSearchCV(arg)) {
195+
return true;
196+
}
197+
} while(current != null);
198+
return false;
195199
}
196200

197201
private static boolean isSetParamsCalled(CallExpression callExpression) {
@@ -216,8 +220,6 @@ private static boolean isUsedWithSetParams(List<Usage> usages) {
216220
private static boolean isPartOfPipelineAndSearchCV(CallExpression callExpression) {
217221
return Expressions.getAssignedName(callExpression)
218222
.map(SkLearnCheck::isEstimatorUsedInSearchCV)
219-
.or(() -> getPipelineAssignement(callExpression)
220-
.map(SkLearnCheck::isEstimatorUsedInSearchCV))
221223
.orElse(false);
222224
}
223225

@@ -241,15 +243,5 @@ private static boolean isArgumentPartOfSearchCV(RegularArgument arg) {
241243
.map(SEARCH_CV_FQNS::contains)
242244
.orElse(false);
243245
}
244-
245-
private static Optional<Name> getPipelineAssignement(CallExpression callExpression) {
246-
return Optional.ofNullable(TreeUtils.firstAncestorOfKind(callExpression, CALL_EXPR))
247-
.flatMap(TreeUtils.toOptionalInstanceOfMapper(CallExpression.class))
248-
.filter(callExp -> Optional.ofNullable(callExp.calleeSymbol())
249-
.map(Symbol::fullyQualifiedName)
250-
.map(PIPELINE_FQNS::contains)
251-
.orElse(false))
252-
.flatMap(Expressions::getAssignedName);
253-
}
254246
}
255247
}

python-checks/src/test/resources/checks/pytorch_optimizer_hyperparameters.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ def adadelta(param_dict, param_list, rho, eps):
1515
from torch.optim import Adadelta
1616
noncompliant = Adadelta() # Noncompliant {{Add the missing hyperparameters lr and weight_decay for this PyTorch optimizer.}}
1717
noncompliant = Adadelta(some_extra_variable=3) # Noncompliant {{Add the missing hyperparameters lr and weight_decay for this PyTorch optimizer.}}
18-
noncompliant = Adadelta(model.parameters(), lr = 0.001) # Noncompliant
18+
noncompliant = Adadelta(model.parameters(), lr = 0.001) # Noncompliant {{Add the missing hyperparameter weight_decay for this PyTorch optimizer.}}
1919
noncompliant = Adadelta(model.parameters(), weight_decay = 0.32) # Noncompliant
2020

2121
optimizer = Adadelta(model.parameters(), 0.001, rho, eps, 0.23)

python-checks/src/test/resources/checks/sklearn_estimator_hyperparameters.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
RandomForestClassifier, RandomForestRegressor,
66
)
77
from sklearn.linear_model import ElasticNet
8-
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier,KNeighborsRegressor
8+
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier,KNeighborsRegressor
99
from sklearn.svm import SVC, SVR, NuSVC, NuSVR
1010
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
1111
from sklearn.neural_network import MLPClassifier, MLPRegressor
@@ -36,7 +36,7 @@ def non_compliant():
3636
SVC(random_state=42) # Noncompliant
3737
SVC(C=1) # Noncompliant
3838
SVR() # Noncompliant {{Add the missing hyperparameters C, kernel and gamma for this Scikit-learn estimator.}}
39-
SVR(C=1.2, kernel="poly") # Noncompliant
39+
SVR(C=1.2, kernel="poly") # Noncompliant {{Add the missing hyperparameter gamma for this Scikit-learn estimator.}}
4040
NuSVC() # Noncompliant
4141
NuSVR(gamma="scale", kernel="poly") # Noncompliant
4242

@@ -95,7 +95,10 @@ def compliant():
9595

9696
pipe2 = Pipeline([('svc'), SVC()]) # FN
9797

98-
grid2 = GridSearchCV(pipe2, param_grid={'svc__C': [1, 10, 100]})
98+
grid2 = GridSearchCV(pipe2, param_grid={'svc__C': [1, 10, 100]})
99+
100+
grid3 = GridSearchCV(make_pipeline(SVC()), param_grid={'svc__C': [1, 10, 100]})
101+
grid4 = GridSearchCV(some_method(some_other_method(make_pipeline(SVC()))), param_grid={'svc__C': [1, 10, 100]})
99102

100103
GridSearchCV(s, param_grid={'svc__C': [1, 10, 100]}) # FN
101104

0 commit comments

Comments
 (0)