Skip to content

Commit e17fd70

Browse files
authored
SONARPY-1780 Rule S6974: Subclasses of Scikit-Learn's "BaseEstimator" should not set attributes ending with "_" in the "__init__" method (#1768)
1 parent bf7efe7 commit e17fd70

File tree

8 files changed

+400
-1
lines changed

8 files changed

+400
-1
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"project:pecos/examples/qp2q/models/pecosq2q.py": [
3+
148
4+
]
5+
}

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ public static Iterable<Class> getChecks() {
338338
SingleCharacterAlternationCheck.class,
339339
SingleCharCharacterClassCheck.class,
340340
SkippedTestNoReasonCheck.class,
341+
SkLearnEstimatorDontInitializeEstimatedValuesCheck.class,
341342
SpecialMethodParamListCheck.class,
342343
SpecialMethodReturnTypeCheck.class,
343344
SQLQueriesCheck.class,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.HashMap;
23+
import java.util.Map;
24+
import java.util.Optional;
25+
import java.util.Set;
26+
import java.util.stream.Stream;
27+
import org.sonar.check.Rule;
28+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
29+
import org.sonar.plugins.python.api.SubscriptionContext;
30+
import org.sonar.plugins.python.api.quickfix.PythonQuickFix;
31+
import org.sonar.plugins.python.api.symbols.ClassSymbol;
32+
import org.sonar.plugins.python.api.tree.AssignmentStatement;
33+
import org.sonar.plugins.python.api.tree.BaseTreeVisitor;
34+
import org.sonar.plugins.python.api.tree.ClassDef;
35+
import org.sonar.plugins.python.api.tree.QualifiedExpression;
36+
import org.sonar.plugins.python.api.tree.Tree;
37+
import org.sonar.plugins.python.api.tree.Tuple;
38+
import org.sonar.python.quickfix.TextEditUtils;
39+
import org.sonar.python.tree.FunctionDefImpl;
40+
import org.sonar.python.tree.TreeUtils;
41+
42+
@Rule(key = "S6974")
43+
public class SkLearnEstimatorDontInitializeEstimatedValuesCheck extends PythonSubscriptionCheck {
44+
45+
private static final String BASE_ESTIMATOR_FULLY_QUALIFIED_NAME = "sklearn.base.BaseEstimator";
46+
private static final Set<String> MIXINS_FULLY_QUALIFIED_NAME = Set.of(
47+
"sklearn.base.BiclusterMixin",
48+
"sklearn.base.ClassifierMixin",
49+
"sklearn.base.ClusterMixin",
50+
"sklearn.base.DensityMixin",
51+
"sklearn.base.MetaEstimatorMixin",
52+
"sklearn.base.OneToOneFeatureMixin",
53+
"sklearn.base.OutlierMixin",
54+
"sklearn.base.RegressorMixin",
55+
"sklearn.base.TransformerMixin");
56+
57+
private static final String MESSAGE = "Move this estimated attribute in the `fit` method.";
58+
private static final String MESSAGE_SECONDARY = "The attribute is used in this estimator";
59+
public static final String QUICK_FIX_MESSAGE = "Remove the statement";
60+
public static final String QUICK_FIX_RENAME_MESSAGE = "Remove all trailing underscores from the variable name";
61+
62+
@Override
63+
public void initialize(Context context) {
64+
context.registerSyntaxNodeConsumer(Tree.Kind.FUNCDEF, SkLearnEstimatorDontInitializeEstimatedValuesCheck::checkFunction);
65+
}
66+
67+
private static boolean inheritsMixin(ClassSymbol classSymbol) {
68+
return MIXINS_FULLY_QUALIFIED_NAME.stream().anyMatch(classSymbol::isOrExtends);
69+
}
70+
71+
private static void checkFunction(SubscriptionContext subscriptionContext) {
72+
FunctionDefImpl functionDef = (FunctionDefImpl) subscriptionContext.syntaxNode();
73+
if (!"__init__".equals(functionDef.name().name())) {
74+
return;
75+
}
76+
77+
var classDef = (ClassDef) TreeUtils.firstAncestorOfKind(functionDef, Tree.Kind.CLASSDEF);
78+
if (classDef == null) {
79+
return;
80+
}
81+
var classSymbol = TreeUtils.getClassSymbolFromDef(classDef);
82+
if (classSymbol == null) {
83+
return;
84+
}
85+
boolean inheritsBaseEstimator = Optional.of(classSymbol)
86+
.map(classSymbol1 -> classSymbol1.isOrExtends(BASE_ESTIMATOR_FULLY_QUALIFIED_NAME))
87+
.orElse(false);
88+
89+
if (!inheritsBaseEstimator && !inheritsMixin(classSymbol)) {
90+
return;
91+
}
92+
93+
var visitor = new VariableDeclarationEndingWithUnderscoreVisitor();
94+
functionDef.body().accept(visitor);
95+
var offendingVariables = visitor.qualifiedExpressions;
96+
var secondaryLocation = classDef.name();
97+
offendingVariables
98+
.forEach((qualifiedExpression, assignmentStatement) -> {
99+
var issue = subscriptionContext.addIssue(qualifiedExpression.name(), MESSAGE).secondary(secondaryLocation, MESSAGE_SECONDARY);
100+
101+
createQuickFix(assignmentStatement).ifPresent(issue::addQuickFix);
102+
issue.addQuickFix(createQuickFixRename(qualifiedExpression));
103+
});
104+
}
105+
106+
private static PythonQuickFix createQuickFixRename(QualifiedExpression qualifiedExpression) {
107+
var quickFix = PythonQuickFix.newQuickFix(QUICK_FIX_RENAME_MESSAGE);
108+
var newName = qualifiedExpression.name().name().replaceAll("_+$", "");
109+
return quickFix.addTextEdit(TextEditUtils.renameAllUsages(qualifiedExpression.name(), newName)).build();
110+
}
111+
112+
private static Optional<PythonQuickFix> createQuickFix(AssignmentStatement assignmentStatement) {
113+
114+
var builder = PythonQuickFix.newQuickFix(QUICK_FIX_MESSAGE);
115+
116+
if (assignmentStatement.lhsExpressions().size() != 1 || assignmentStatement.lhsExpressions().stream().anyMatch(expressions -> expressions.expressions().size() != 1)) {
117+
return Optional.empty();
118+
}
119+
builder.addTextEdit(TextEditUtils.removeStatement(assignmentStatement));
120+
if (assignmentStatement.assignedValue().is(Tree.Kind.NONE)) {
121+
return Optional.of(builder.build());
122+
}
123+
return Optional.empty();
124+
}
125+
126+
private static class VariableDeclarationEndingWithUnderscoreVisitor extends BaseTreeVisitor {
127+
128+
private final Map<QualifiedExpression, AssignmentStatement> qualifiedExpressions = new HashMap<>();
129+
130+
private static boolean isOffendingQualifiedExpression(QualifiedExpression qualifiedExpression) {
131+
return !qualifiedExpression.name().name().startsWith("__") && qualifiedExpression.name().name().endsWith("_") && qualifiedExpression.qualifier().is(Tree.Kind.NAME)
132+
&& "self".equals(((org.sonar.plugins.python.api.tree.Name) qualifiedExpression.qualifier()).name());
133+
}
134+
135+
@Override
136+
public void visitAssignmentStatement(AssignmentStatement pyAssignmentStatementTree) {
137+
var offendingQualifiedExpressions = pyAssignmentStatementTree.lhsExpressions()
138+
.stream()
139+
.flatMap(expressionList -> expressionList.expressions().stream())
140+
.filter(expression -> expression.is(Tree.Kind.QUALIFIED_EXPR))
141+
.map(QualifiedExpression.class::cast);
142+
143+
var offendingTuples = pyAssignmentStatementTree.lhsExpressions()
144+
.stream()
145+
.flatMap(expressionList -> expressionList.expressions().stream())
146+
.filter(expression -> expression.is(Tree.Kind.TUPLE))
147+
.map(Tuple.class::cast)
148+
.flatMap(tuple -> tuple.elements().stream())
149+
.filter(expression -> expression.is(Tree.Kind.QUALIFIED_EXPR))
150+
.map(QualifiedExpression.class::cast);
151+
152+
Stream.concat(
153+
offendingQualifiedExpressions, offendingTuples)
154+
.filter(VariableDeclarationEndingWithUnderscoreVisitor::isOffendingQualifiedExpression)
155+
.forEach(qualifiedExpression -> qualifiedExpressions.put(qualifiedExpression, pyAssignmentStatementTree));
156+
super.visitAssignmentStatement(pyAssignmentStatementTree);
157+
}
158+
}
159+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
<p>This rule raises an issue when an attribute ending with <code>_</code> is set in the <code>__init__</code> method of a class inheriting from
2+
Scikit-Learn’s <code>BaseEstimator</code>.</p>
3+
<h2>Why is this an issue?</h2>
4+
<p>On a Scikit-Learn estimator, attributes that have a trailing underscore represent attributes that are estimated. These attributes have to be set
5+
in the <code>fit</code> method. Their presence is used to verify whether an estimator has been fitted.</p>
6+
<pre>
7+
from sklearn.neighbors import KNeighborsClassifier
8+
9+
X = [[0], [1], [2], [3]]
10+
y = [0, 0, 1, 1]
11+
knn = KNeighborsClassifier(n_neighbors=1)
12+
knn.fit(X, y)
13+
knn.n_samples_fit_
14+
</pre>
15+
<p>In the example above, the attribute <code>n_samples_fit_</code> of the <code>KNeighborsClassifier</code> is set only after the estimator’s
16+
<code>fit</code> method is called. Accessing <code>n_samples_fit_</code> before the estimator is fitted would raise an <code>AttributeError</code>
17+
exception.</p>
18+
<p>When implementing a custom estimator by subclassing Scikit-Learn’s <code>BaseEstimator</code>, it is important to follow the above convention and
19+
not set attributes with a trailing underscore inside the <code>__init__</code> method.</p>
20+
<h2>How to fix it</h2>
21+
<p>To fix this issue, move the attributes with a trailing underscore from the <code>__init__</code> method to the <code>fit</code> method.</p>
22+
<h3>Code examples</h3>
23+
<h4>Noncompliant code example</h4>
24+
<pre data-diff-id="1" data-diff-type="noncompliant">
25+
from sklearn.base import BaseEstimator
26+
27+
class MyEstimator(BaseEstimator):
28+
def __init__(self):
29+
self.estimated_attribute_ = None # Noncompliant: an estimated attribute is set in the __init__ method.
30+
</pre>
31+
<h4>Compliant solution</h4>
32+
<pre data-diff-id="1" data-diff-type="compliant">
33+
from sklearn.base import BaseEstimator
34+
35+
class MyEstimator(BaseEstimator):
36+
def fit(self, X, y):
37+
self.estimated_attribute_ = some_estimation(X) # Compliant
38+
</pre>
39+
<h2>Resources</h2>
40+
<h3>Documentation</h3>
41+
<ul>
42+
<li> Scikit-Learn documentation - <a href="https://scikit-learn.org/stable/developers/develop.html#parameters-and-init">Parameters and init</a>
43+
</li>
44+
<li> {rule:python:S6970} - The Scikit-learn <code>fit</code> method should be called before methods yielding results </li>
45+
</ul>
46+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
"title": "Subclasses of Scikit-Learn\u0027s \"BaseEstimator\" should not set attributes ending with \"_\" in the \"__init__\" method",
3+
"type": "CODE_SMELL",
4+
"status": "ready",
5+
"remediation": {
6+
"func": "Constant\/Issue",
7+
"constantCost": "5min"
8+
},
9+
"tags": [],
10+
"defaultSeverity": "Major",
11+
"ruleSpecification": "RSPEC-6974",
12+
"sqKey": "S6974",
13+
"scope": "All",
14+
"quickfix": "partial",
15+
"code": {
16+
"impacts": {
17+
"MAINTAINABILITY": "MEDIUM",
18+
"RELIABILITY": "HIGH"
19+
},
20+
"attribute": "CONVENTIONAL"
21+
}
22+
}

python-checks/src/main/resources/org/sonar/l10n/py/rules/python/Sonar_way_profile.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@
240240
"S6928",
241241
"S6929",
242242
"S6969",
243-
"S6971"
243+
"S6971",
244+
"S6974"
244245
]
245246
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.sonar.python.checks.quickfix.PythonQuickFixVerifier;
24+
import org.sonar.python.checks.utils.PythonCheckVerifier;
25+
26+
/**
 * Tests for {@link SkLearnEstimatorDontInitializeEstimatedValuesCheck}: one run over a Python fixture
 * file and three quick-fix scenarios. In each quick-fix scenario the first text block is the input and
 * the two following ones are the expected results (statement removed, then attribute renamed).
 */
class SkLearnEstimatorDontInitializeEstimatedValuesCheckTest {
27+
// Runs the check over the fixture file; expected issues are presumably declared in that file - confirm against PythonCheckVerifier.
@Test
28+
void test() {
29+
PythonCheckVerifier.verify("src/test/resources/checks/sklearn_estimator_underscore_initialization.py", new SkLearnEstimatorDontInitializeEstimatedValuesCheck());
30+
}
31+
32+
// BaseEstimator subclass whose __init__ holds the assignment followed by `...`: removal leaves `...`, rename yields self.a.
@Test
33+
void testQuickfix1() {
34+
PythonQuickFixVerifier.verify(
35+
new SkLearnEstimatorDontInitializeEstimatedValuesCheck(),
36+
"""
37+
from sklearn.base import BaseEstimator
38+
class InheritingEstimator(BaseEstimator):
39+
def __init__(self) -> None:
40+
self.a_ = None
41+
...""",
42+
"""
43+
from sklearn.base import BaseEstimator
44+
class InheritingEstimator(BaseEstimator):
45+
def __init__(self) -> None:
46+
...""",
47+
"""
48+
from sklearn.base import BaseEstimator
49+
class InheritingEstimator(BaseEstimator):
50+
def __init__(self) -> None:
51+
self.a = None
52+
..."""
53+
);
54+
}
55+
// Name with a leading underscore and several trailing ones, as the only statement: removal leaves `pass`, rename strips every trailing underscore.
@Test
56+
void testQuickfix2() {
57+
PythonQuickFixVerifier.verify(
58+
new SkLearnEstimatorDontInitializeEstimatedValuesCheck(),
59+
"""
60+
from sklearn.base import BaseEstimator
61+
class InheritingEstimator(BaseEstimator):
62+
def __init__(self) -> None:
63+
self._something_a_______ = None""",
64+
"""
65+
from sklearn.base import BaseEstimator
66+
class InheritingEstimator(BaseEstimator):
67+
def __init__(self) -> None:
68+
pass""",
69+
"""
70+
from sklearn.base import BaseEstimator
71+
class InheritingEstimator(BaseEstimator):
72+
def __init__(self) -> None:
73+
self._something_a = None"""
74+
);
75+
}
76+
// Same single-statement scenario, but the class extends ClassifierMixin instead of BaseEstimator.
@Test
77+
void testQuickfixEmptyFunc() {
78+
PythonQuickFixVerifier.verify(
79+
new SkLearnEstimatorDontInitializeEstimatedValuesCheck(),
80+
"""
81+
from sklearn.base import ClassifierMixin
82+
class InheritingEstimator(ClassifierMixin):
83+
def __init__(self) -> None:
84+
self.a_ = None""",
85+
"""
86+
from sklearn.base import ClassifierMixin
87+
class InheritingEstimator(ClassifierMixin):
88+
def __init__(self) -> None:
89+
pass""",
90+
"""
91+
from sklearn.base import ClassifierMixin
92+
class InheritingEstimator(ClassifierMixin):
93+
def __init__(self) -> None:
94+
self.a = None"""
95+
);
96+
}
97+
}

0 commit comments

Comments
 (0)