19
19
*/
20
20
package org .sonar .python .checks ;
21
21
22
+ import java .util .Collections ;
22
23
import java .util .List ;
23
24
import java .util .Map ;
24
25
import java .util .Objects ;
@@ -74,19 +75,20 @@ private static void checkEstimator(SubscriptionContext ctx) {
74
75
}
75
76
76
77
private static void checkPyTorchOptimizer (String name , CallExpression callExpression , SubscriptionContext ctx ) {
77
- PyTorchCheck .getMissingParameters (name , callExpression )
78
- .map (MissingHyperParameterCheck ::toParameterNames )
79
- .ifPresent (parameters -> ctx .addIssue (callExpression , formatMessage (parameters , PYTORCH_MESSAGE )));
80
- }
78
+ List <String > missingParams = PyTorchCheck .getMissingParameters (name , callExpression ).stream ()
79
+ .map (Param ::name )
80
+ .toList ();
81
81
82
- private static void checkSkLearnEstimator (String name , CallExpression callExpression , SubscriptionContext ctx ) {
83
- SkLearnCheck .getMissingParameters (name , callExpression )
84
- .map (MissingHyperParameterCheck ::toParameterNames )
85
- .ifPresent (parameters -> ctx .addIssue (callExpression , formatMessage (parameters , SKLEARN_MESSAGE )));
82
+ if (!missingParams .isEmpty ()) {
83
+ ctx .addIssue (callExpression , formatMessage (missingParams , PYTORCH_MESSAGE ));
84
+ }
86
85
}
87
86
88
- private static List <String > toParameterNames (List <Param > parameters ) {
89
- return parameters .stream ().map (Param ::name ).toList ();
87
+ private static void checkSkLearnEstimator (String name , CallExpression callExpression , SubscriptionContext ctx ) {
88
+ List <String > missingParams = SkLearnCheck .getMissingParameters (name , callExpression ).stream ().map (Param ::name ).toList ();
89
+ if (!missingParams .isEmpty ()) {
90
+ ctx .addIssue (callExpression , formatMessage (missingParams , SKLEARN_MESSAGE ));
91
+ }
90
92
}
91
93
92
94
private static String formatMessage (List <String > missingArgs , String formatString ) {
@@ -101,12 +103,12 @@ private static String formatMessage(List<String> missingArgs, String formatStrin
101
103
102
104
103
105
// common method used by both the PyTorchCheck class and SkLearnCheck class
104
- private static boolean isMissingAHyperparameter (CallExpression callExpression , List <Param > parametersToCheck ) {
106
+ private static List < Param > filterUsedHyperparameter (CallExpression callExpression , List <Param > parametersToCheck ) {
105
107
return parametersToCheck .stream ()
106
- .map (param -> param .position ()
108
+ .filter (param -> param .position ()
107
109
.map (position -> TreeUtils .nthArgumentOrKeyword (position , param .name , callExpression .arguments ()))
108
- .orElse (TreeUtils .argumentByKeyword (param .name , callExpression .arguments ())))
109
- .anyMatch ( Objects :: isNull );
110
+ .orElse (TreeUtils .argumentByKeyword (param .name , callExpression .arguments ())) == null )
111
+ .toList ( );
110
112
}
111
113
112
114
private static class PyTorchCheck {
@@ -130,10 +132,11 @@ private static class PyTorchCheck {
130
132
Map .entry ("torch.optim.SGD" , List .of (new Param (LR , 1 ), new Param ("momentum" , 2 ), new Param (WEIGHT_DECAY , 4 )))
131
133
);
132
134
133
- public static Optional < List <Param > > getMissingParameters (String name , CallExpression callExpression ) {
135
+ public static List <Param > getMissingParameters (String name , CallExpression callExpression ) {
134
136
return Optional .ofNullable (PY_TORCH_ESTIMATORS_AND_PARAMETERS_TO_CHECK .get (name ))
135
137
.filter (parameters -> !Expressions .containsSpreadOperator (callExpression .arguments ()))
136
- .filter (parameters -> isMissingAHyperparameter (callExpression , parameters ));
138
+ .map (parameters -> filterUsedHyperparameter (callExpression , parameters ))
139
+ .orElse (Collections .emptyList ());
137
140
}
138
141
}
139
142
@@ -175,23 +178,24 @@ private static class SkLearnCheck {
175
178
"sklearn.model_selection._search_successive_halving.HalvingRandomSearchCV" ,
176
179
"sklearn.model_selection._search_successive_halving.HalvingGridSearchCV" );
177
180
178
- private static final Set <String > PIPELINE_FQNS = Set .of (
179
- "sklearn.pipeline.make_pipeline" ,
180
- "sklearn.pipeline.Pipeline" );
181
-
182
- public static Optional <List <Param >> getMissingParameters (String name , CallExpression callExpression ) {
181
+ public static List <Param > getMissingParameters (String name , CallExpression callExpression ) {
183
182
return Optional .ofNullable (SK_LEARN_ESTIMATORS_AND_PARAMETERS_TO_CHECK .get (name ))
184
183
.filter (parameters -> !isDirectlyUsedInSearchCV (callExpression ))
185
184
.filter (parameters -> !isSetParamsCalled (callExpression ))
186
185
.filter (parameters -> !isPartOfPipelineAndSearchCV (callExpression ))
187
- .filter (parameters -> isMissingAHyperparameter (callExpression , parameters ));
186
+ .map (parameters -> filterUsedHyperparameter (callExpression , parameters ))
187
+ .orElse (Collections .emptyList ());
188
188
}
189
189
190
190
private static boolean isDirectlyUsedInSearchCV (CallExpression callExpression ) {
191
- return Optional .ofNullable (TreeUtils .firstAncestorOfKind (callExpression , REGULAR_ARGUMENT ))
192
- .flatMap (TreeUtils .toOptionalInstanceOfMapper (RegularArgument .class ))
193
- .map (SkLearnCheck ::isArgumentPartOfSearchCV )
194
- .orElse (false );
191
+ Tree current = callExpression ;
192
+ do {
193
+ current = TreeUtils .firstAncestorOfKind (current , REGULAR_ARGUMENT );
194
+ if (current instanceof RegularArgument arg && isArgumentPartOfSearchCV (arg )) {
195
+ return true ;
196
+ }
197
+ } while (current != null );
198
+ return false ;
195
199
}
196
200
197
201
private static boolean isSetParamsCalled (CallExpression callExpression ) {
@@ -216,8 +220,6 @@ private static boolean isUsedWithSetParams(List<Usage> usages) {
216
220
private static boolean isPartOfPipelineAndSearchCV (CallExpression callExpression ) {
217
221
return Expressions .getAssignedName (callExpression )
218
222
.map (SkLearnCheck ::isEstimatorUsedInSearchCV )
219
- .or (() -> getPipelineAssignement (callExpression )
220
- .map (SkLearnCheck ::isEstimatorUsedInSearchCV ))
221
223
.orElse (false );
222
224
}
223
225
@@ -241,15 +243,5 @@ private static boolean isArgumentPartOfSearchCV(RegularArgument arg) {
241
243
.map (SEARCH_CV_FQNS ::contains )
242
244
.orElse (false );
243
245
}
244
-
245
- private static Optional <Name > getPipelineAssignement (CallExpression callExpression ) {
246
- return Optional .ofNullable (TreeUtils .firstAncestorOfKind (callExpression , CALL_EXPR ))
247
- .flatMap (TreeUtils .toOptionalInstanceOfMapper (CallExpression .class ))
248
- .filter (callExp -> Optional .ofNullable (callExp .calleeSymbol ())
249
- .map (Symbol ::fullyQualifiedName )
250
- .map (PIPELINE_FQNS ::contains )
251
- .orElse (false ))
252
- .flatMap (Expressions ::getAssignedName );
253
- }
254
246
}
255
247
}
0 commit comments