Skip to content

Commit f31a720

Browse files
authored
SONARPY-1769: Scikit-Learn random_state check (#1785)
1 parent e17fd70 commit f31a720

File tree

6 files changed

+191
-105
lines changed

6 files changed

+191
-105
lines changed

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ public static Iterable<Class> getChecks() {
234234
IncompatibleOperandsCheck.class,
235235
InconsistentTypeHintCheck.class,
236236
IncorrectExceptionTypeCheck.class,
237+
IncorrectParameterDatetimeConstructorsCheck.class,
237238
IndexMethodCheck.class,
238239
InequalityUsageCheck.class,
239240
InfiniteRecursionCheck.class,
@@ -282,7 +283,6 @@ public static Iterable<Class> getChecks() {
282283
NumpyWeekMaskValidationCheck.class,
283284
NumpyIsNanCheck.class,
284285
NumpyListOverGeneratorCheck.class,
285-
NumpyRandomSeedCheck.class,
286286
NumpyRandomStateCheck.class,
287287
NumpyWhereOneConditionCheck.class,
288288
UnusedGroupNamesCheck.class,
@@ -304,13 +304,13 @@ public static Iterable<Class> getChecks() {
304304
ProcessSignallingCheck.class,
305305
PropertyAccessorParameterCountCheck.class,
306306
PytzUsageCheck.class,
307-
IncorrectParameterDatetimeConstructorsCheck.class,
308307
PseudoRandomCheck.class,
309308
PublicApiIsSecuritySensitiveCheck.class,
310309
PubliclyWritableDirectoriesCheck.class,
311310
PublicNetworkAccessToCloudResourcesCheck.class,
312311
PytzTimeZoneInDatetimeConstructorCheck.class,
313312
RaiseOutsideExceptCheck.class,
313+
RandomSeedCheck.class,
314314
RedosCheck.class,
315315
RedundantJumpCheck.class,
316316
PossessiveQuantifierContinuationCheck.class,

python-checks/src/main/java/org/sonar/python/checks/NumpyRandomSeedCheck.java

Lines changed: 0 additions & 88 deletions
This file was deleted.
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.Map;
23+
import java.util.Optional;
24+
import java.util.Set;
25+
import java.util.function.Predicate;
26+
import javax.annotation.Nullable;
27+
import org.sonar.check.Rule;
28+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
29+
import org.sonar.plugins.python.api.SubscriptionContext;
30+
import org.sonar.plugins.python.api.symbols.FunctionSymbol;
31+
import org.sonar.plugins.python.api.symbols.Symbol;
32+
import org.sonar.plugins.python.api.symbols.Symbol.Kind;
33+
import org.sonar.plugins.python.api.tree.CallExpression;
34+
import org.sonar.plugins.python.api.tree.Expression;
35+
import org.sonar.plugins.python.api.tree.Name;
36+
import org.sonar.plugins.python.api.tree.RegularArgument;
37+
import org.sonar.plugins.python.api.tree.Tree;
38+
import org.sonar.python.cfg.fixpoint.ReachingDefinitionsAnalysis;
39+
import org.sonar.python.semantic.ClassSymbolImpl;
40+
import org.sonar.python.semantic.SymbolUtils;
41+
import org.sonar.python.tree.TreeUtils;
42+
43+
@Rule(key = "S6709")
44+
public class RandomSeedCheck extends PythonSubscriptionCheck {
45+
46+
private static final String NUMPY_SEED_ARG_NAME = "seed";
47+
48+
private static final Map<String, String> SEED_METHODS_TO_CHECK = Map.of(
49+
"numpy.seed", NUMPY_SEED_ARG_NAME,
50+
"numpy.random.seed", NUMPY_SEED_ARG_NAME,
51+
"numpy.random.default_rng", NUMPY_SEED_ARG_NAME,
52+
"numpy.random.SeedSequence", "entropy",
53+
"numpy.random.PCG64", NUMPY_SEED_ARG_NAME,
54+
"numpy.random.PCG64DXSM", NUMPY_SEED_ARG_NAME,
55+
"numpy.random.MT19937", NUMPY_SEED_ARG_NAME,
56+
"numpy.random.SFC64", NUMPY_SEED_ARG_NAME,
57+
"numpy.random.Philox", NUMPY_SEED_ARG_NAME);
58+
59+
private static final String SKLEARN_FQN = "sklearn";
60+
private static final String SKLEARN_ARG_NAME = "random_state";
61+
62+
private static final String MESSAGE = "Provide a seed for this random generator.";
63+
private static final String SKLEARN_MESSAGE = "Provide a seed for the random_state parameter.";
64+
65+
private ReachingDefinitionsAnalysis reachingDefinitionsAnalysis;
66+
67+
@Override
68+
public void initialize(Context context) {
69+
context.registerSyntaxNodeConsumer(Tree.Kind.FILE_INPUT,
70+
ctx -> this.reachingDefinitionsAnalysis = new ReachingDefinitionsAnalysis(ctx.pythonFile()));
71+
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, this::checkEmptySeedCall);
72+
}
73+
74+
private void checkEmptySeedCall(SubscriptionContext ctx) {
75+
CallExpression call = (CallExpression) ctx.syntaxNode();
76+
77+
Optional<Symbol> maybeCalleeSymbol = Optional.ofNullable(call.calleeSymbol());
78+
79+
maybeCalleeSymbol
80+
.map(Symbol::fullyQualifiedName)
81+
.map(SEED_METHODS_TO_CHECK::get)
82+
.filter(argName -> isArgumentAbsentOrNone(TreeUtils.nthArgumentOrKeyword(0, argName, call.arguments())))
83+
.map(arg -> MESSAGE)
84+
.or(() -> maybeCalleeSymbol
85+
.filter(symbol -> symbol.fullyQualifiedName() != null && symbol.fullyQualifiedName().startsWith(SKLEARN_FQN))
86+
.filter(RandomSeedCheck::hasRandomStateParameter)
87+
.filter(symbol -> isArgumentAbsentOrNone(TreeUtils.argumentByKeyword(SKLEARN_ARG_NAME, call.arguments())))
88+
.map(symbol -> SKLEARN_MESSAGE))
89+
.ifPresent(message -> ctx.addIssue(call.callee(), message));
90+
}
91+
92+
private static boolean hasRandomStateParameter(Symbol calleeSymbol) {
93+
return isClassInstantiationWithRandomStateParameter(calleeSymbol)
94+
.or(() -> isFunctionWithRandomStateParameter(calleeSymbol))
95+
.orElse(false);
96+
}
97+
98+
private static Optional<Boolean> isClassInstantiationWithRandomStateParameter(Symbol calleeSymbol) {
99+
return Optional.of(calleeSymbol)
100+
.filter(s -> s.is(Kind.CLASS))
101+
.map(ClassSymbolImpl.class::cast)
102+
.map(classSymbol -> classSymbol.declaredMembers()
103+
.stream()
104+
.filter(member -> "__init__".equals(member.name()))
105+
.toList())
106+
.filter(members -> members.size() == 1)
107+
.map(members -> members.get(0))
108+
.map(RandomSeedCheck::hasRandomStateParameter);
109+
}
110+
111+
private static Optional<Boolean> isFunctionWithRandomStateParameter(Symbol calleeSymbol) {
112+
return Optional.of(calleeSymbol)
113+
.filter(s1 -> s1.is(Kind.FUNCTION))
114+
.map(SymbolUtils::getFunctionSymbols)
115+
.filter(symbols -> symbols.size() == 1)
116+
.map(symbols -> symbols.get(0))
117+
.map(symbol -> symbol.parameters()
118+
.stream()
119+
.map(FunctionSymbol.Parameter::name)
120+
.anyMatch(SKLEARN_ARG_NAME::equals));
121+
}
122+
123+
private boolean isArgumentAbsentOrNone(@Nullable RegularArgument arg) {
124+
return arg == null || arg.expression().is(Tree.Kind.NONE) || isAssignedNone(arg.expression());
125+
}
126+
127+
private boolean isAssignedNone(Expression exp) {
128+
return Optional.of(exp)
129+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(Name.class))
130+
.map(reachingDefinitionsAnalysis::valuesAtLocation)
131+
.filter(Predicate.not(Set::isEmpty))
132+
.filter(values -> values.stream().allMatch(value -> value.is(Tree.Kind.NONE))).isPresent();
133+
}
134+
}

python-checks/src/test/java/org/sonar/python/checks/NumpyRandomSeedCheckTest.java renamed to python-checks/src/test/java/org/sonar/python/checks/RandomSeedCheckTest.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,17 @@
2222
import org.junit.jupiter.api.Test;
2323
import org.sonar.python.checks.utils.PythonCheckVerifier;
2424

25-
class NumpyRandomSeedCheckTest {
25+
class RandomSeedCheckTest {
26+
RandomSeedCheck check = new RandomSeedCheck();
27+
2628
@Test
27-
void test() {
28-
PythonCheckVerifier.verify("src/test/resources/checks/numpyRandomSeedCheck.py", new NumpyRandomSeedCheck());
29+
void test_numpy() {
30+
PythonCheckVerifier.verify("src/test/resources/checks/randomSeedNumpy.py", check );
31+
}
32+
33+
@Test
34+
void test_sklearn() {
35+
PythonCheckVerifier.verify("src/test/resources/checks/randomSeedSKlearn.py", check);
2936
}
30-
3137
}
3238

python-checks/src/test/resources/checks/numpyRandomSeedCheck.py renamed to python-checks/src/test/resources/checks/randomSeedNumpy.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,35 @@
22

33
def failure():
44
gen = np.random.default_rng() # Noncompliant {{Provide a seed for this random generator.}}
5-
# ^^^^^^^^^^^^^^^^^^^^^^^
5+
# ^^^^^^^^^^^^^^^^^^^^^
66

77
gen = np.random.SeedSequence() # Noncompliant
8-
# ^^^^^^^^^^^^^^^^^^^^^^^^
8+
# ^^^^^^^^^^^^^^^^^^^^^^
99

1010
gen = np.random.SeedSequence(entropy=None) # Noncompliant
11-
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
11+
# ^^^^^^^^^^^^^^^^^^^^^^
1212
gen = np.random.SeedSequence(spawn_key=[123]) # Noncompliant
13-
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
13+
# ^^^^^^^^^^^^^^^^^^^^^^
1414

1515
from numpy.random import SFC64, MT19937, PCG64, PCG64DXSM, Philox
1616

1717
gen = SFC64() # Noncompliant
18-
# ^^^^^^^
18+
# ^^^^^
1919
gen = MT19937(seed=None) # Noncompliant
20-
# ^^^^^^^^^^^^^^^^^^
21-
gen = PCG64() # Noncompliant
2220
# ^^^^^^^
21+
gen = PCG64() # Noncompliant
22+
# ^^^^^
2323

2424
gen = PCG64DXSM() # Noncompliant
25-
# ^^^^^^^^^^^
25+
# ^^^^^^^^^
2626
a = None
2727
gen = Philox(a) # Noncompliant
28-
# ^^^^^^^^^
28+
# ^^^^^^
2929

3030
gen = np.random.seed() # Noncompliant
31-
# ^^^^^^^^^^^^^^^^
31+
# ^^^^^^^^^^^^^^
3232
gen = np.seed(None) # Noncompliant
33-
# ^^^^^^^^^^^^^
33+
# ^^^^^^^
3434

3535

3636
def success():
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from sklearn.model_selection import train_test_split
2+
from sklearn.svm import SVC
3+
from sklearn.datasets import load_iris, make_blobs
4+
5+
def failure():
6+
X, y = load_iris(return_X_y=True)
7+
X_train, X_test, y_train, y_test = train_test_split(X, y) # Noncompliant {{Provide a seed for the random_state parameter.}}
8+
# ^^^^^^^^^^^^^^^^
9+
svc = SVC() # Noncompliant
10+
# ^^^
11+
12+
X, y = make_blobs(n_samples=1300, random_state=None) # Noncompliant
13+
# ^^^^^^^^^^
14+
15+
def success():
16+
from sklearn.ensemble import RandomForestClassifier
17+
rfc = RandomForestClassifier(random_state=0) # Compliant
18+
from sklearn.linear_model import SGDClassifier
19+
sgd = SGDClassifier(random_state=foo()) # Compliant
20+
21+
def sklearn_seed(rng):
22+
svc = SVC(random_state=rng) # Compliant
23+
24+
def foo(random_state=None):
25+
pass
26+
27+
foo() # Compliant
28+
29+
def ambiguous():
30+
from sklearn.svm import SVC as something
31+
from sklearn.datasets import make_blobs as something
32+
33+
something = something() # FN
34+

0 commit comments

Comments
 (0)