Skip to content

Commit c7097e5

Browse files
authored
SONARPY-1898: Subclasses of "torch.nn.Module" should call the initializer (#1974)
* SONARPY-1898 update metadata * SONARPY-1898 implement basic detection logic * SONARPY-1898 add test cases and clean up * SONARPY-1898 add license header and add to checklist * SONARPY-1898 fix sonarqube issues * SONARPY-1898 add test coverage * SONARPY-1898 add quickfix * SONARPY-1898 address reviewer comments * SONARPY-1898 add coverage
1 parent 5947a5d commit c7097e5

File tree

7 files changed

+322
-0
lines changed

7 files changed

+322
-0
lines changed

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ public static Iterable<Class> getChecks() {
369369
TorchAutogradVariableShouldNotBeUsedCheck.class,
370370
TorchLoadLeadsToUntrustedCodeExecutionCheck.class,
371371
TorchModuleModeShouldBeSetAfterLoadingCheck.class,
372+
TorchModuleShouldCallInitCheck.class,
372373
TrailingCommentCheck.class,
373374
TrailingWhitespaceCheck.class,
374375
TypeAliasAnnotationCheck.class,
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.Optional;
23+
import javax.annotation.Nullable;
24+
import org.sonar.check.Rule;
25+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
26+
import org.sonar.plugins.python.api.quickfix.PythonQuickFix;
27+
import org.sonar.plugins.python.api.quickfix.PythonTextEdit;
28+
import org.sonar.plugins.python.api.symbols.FunctionSymbol;
29+
import org.sonar.plugins.python.api.symbols.Symbol;
30+
import org.sonar.plugins.python.api.tree.ArgList;
31+
import org.sonar.plugins.python.api.tree.CallExpression;
32+
import org.sonar.plugins.python.api.tree.ClassDef;
33+
import org.sonar.plugins.python.api.tree.Expression;
34+
import org.sonar.plugins.python.api.tree.FunctionDef;
35+
import org.sonar.plugins.python.api.tree.QualifiedExpression;
36+
import org.sonar.plugins.python.api.tree.RegularArgument;
37+
import org.sonar.plugins.python.api.tree.Tree;
38+
import org.sonar.python.checks.utils.CheckUtils;
39+
import org.sonar.python.quickfix.TextEditUtils;
40+
import org.sonar.python.tree.TreeUtils;
41+
42+
@Rule(key = "S6978")
43+
public class TorchModuleShouldCallInitCheck extends PythonSubscriptionCheck {
44+
private static final String TORCH_NN_MODULE = "torch.nn.Module";
45+
private static final String MESSAGE = "Add a call to super().__init__().";
46+
private static final String SECONDARY_MESSAGE = "Inheritance happens here";
47+
public static final String QUICK_FIX_MESSAGE = "insert call to super constructor";
48+
49+
@Override
50+
public void initialize(Context context) {
51+
context.registerSyntaxNodeConsumer(Tree.Kind.FUNCDEF, ctx -> {
52+
FunctionDef funcDef = (FunctionDef) ctx.syntaxNode();
53+
ClassDef classDef = CheckUtils.getParentClassDef(funcDef);
54+
if (isInheritingFromTorchModule(classDef) && isConstructor(funcDef) && isMissingSuperCall(funcDef)) {
55+
PreciseIssue issue = ctx.addIssue(funcDef.name(), MESSAGE);
56+
issue.secondary(classDef.name(), SECONDARY_MESSAGE);
57+
createQuickFix(funcDef).ifPresent(issue::addQuickFix);
58+
}
59+
});
60+
}
61+
62+
private static boolean isConstructor(FunctionDef funcDef) {
63+
FunctionSymbol symbol = TreeUtils.getFunctionSymbolFromDef(funcDef);
64+
return symbol != null && "__init__".equals(symbol.name()) && funcDef.isMethodDefinition();
65+
}
66+
67+
private static boolean isInheritingFromTorchModule(@Nullable ClassDef classDef) {
68+
if (classDef == null) return false;
69+
ArgList args = classDef.args();
70+
return args != null && args.arguments().stream()
71+
.flatMap(TreeUtils.toStreamInstanceOfMapper(RegularArgument.class))
72+
.map(arg -> getQualifiedName(arg.expression()))
73+
.anyMatch(expr -> expr.filter(TORCH_NN_MODULE::equals).isPresent());
74+
}
75+
76+
private static Optional<String> getQualifiedName(Expression node) {
77+
return TreeUtils.getSymbolFromTree(node).flatMap(symbol -> Optional.ofNullable(symbol.fullyQualifiedName()));
78+
}
79+
80+
private static boolean isMissingSuperCall(FunctionDef funcDef) {
81+
ClassDef parentClassDef = CheckUtils.getParentClassDef(funcDef);
82+
return parentClassDef != null && !TreeUtils.hasDescendant(parentClassDef, t -> t.is(Tree.Kind.CALL_EXPR) && isSuperConstructorCall((CallExpression) t));
83+
}
84+
85+
private static boolean isSuperConstructorCall(CallExpression callExpr) {
86+
return callExpr.callee() instanceof QualifiedExpression qualifiedCallee && isSuperCall(qualifiedCallee.qualifier()) && "__init__".equals(qualifiedCallee.name().name());
87+
}
88+
89+
private static boolean isSuperCall(Expression qualifier) {
90+
if (qualifier instanceof CallExpression callExpression) {
91+
Symbol superSymbol = callExpression.calleeSymbol();
92+
return superSymbol != null && "super".equals(superSymbol.name());
93+
}
94+
return false;
95+
}
96+
97+
private static Optional<PythonQuickFix> createQuickFix(FunctionDef functionDef) {
98+
// it is hard to find the correct indentation when the function def and the body is on the same line (e.g. def test(): pass).
99+
// Thus we don't produce a quickfix in those cases
100+
if(functionDef.colon().line() == functionDef.body().firstToken().line()) {
101+
return Optional.empty();
102+
}
103+
104+
PythonTextEdit pythonTextEdit = TextEditUtils.insertLineAfter(functionDef.colon(), functionDef.body(), "super().__init__()");
105+
return Optional.of(PythonQuickFix.newQuickFix(QUICK_FIX_MESSAGE, pythonTextEdit));
106+
}
107+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
<p>This rule raises an issue when a class is a Pytorch module and does not call the <code>super().__init__()</code> method in its constructor.</p>
2+
<h2>Why is this an issue?</h2>
3+
<p>To provide the AutoGrad functionality, the Pytorch library needs to set up the necessary data structures in the base class. If the
4+
<code>super().__init__()</code> method is not called, the module will not be able to keep track of its parameters and other attributes.</p>
5+
<p>For example, when trying to instantiate a module like <code>nn.Linear</code> without calling the <code>super().__init__()</code> method, the
6+
instantiation will fail when it tries to register it as a submodule of the parent module.</p>
7+
<pre>
8+
import torch.nn as nn
9+
10+
class MyCustomModule(nn.Module):
11+
def __init__(self, input_size, output_size):
12+
self.fc = nn.Linear(input_size, output_size)
13+
14+
model = MyCustomModule(10, 5) # AttributeError: cannot assign module before Module.__init__() call
15+
</pre>
16+
<h2>How to fix it</h2>
17+
<p>Add a call to <code>super().__init__()</code> at the beginning of the constructor of the class.</p>
18+
<h3>Code examples</h3>
19+
<h4>Noncompliant code example</h4>
20+
<pre data-diff-id="1" data-diff-type="noncompliant">
21+
import torch.nn as nn
22+
23+
class MyCustomModule(nn.Module):
24+
def __init__(self, input_size, output_size):
25+
self.fc = nn.Linear(input_size, output_size) # Noncompliant: creating an nn.Linear without calling super().__init__()
26+
</pre>
27+
<h4>Compliant solution</h4>
28+
<pre data-diff-id="1" data-diff-type="compliant">
29+
import torch.nn as nn
30+
31+
class MyCustomModule(nn.Module):
32+
def __init__(self, input_size, output_size):
33+
super().__init__()
34+
self.fc = nn.Linear(input_size, output_size)
35+
</pre>
36+
<h2>Resources</h2>
37+
<h3>Documentation</h3>
38+
<ul>
39+
<li> Pytorch documentation - <a href="https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module">torch.nn.Module</a> </li>
40+
</ul>
41+
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"title": "Subclasses of \"torch.nn.Module\" should call the initializer",
3+
"type": "BUG",
4+
"status": "ready",
5+
"remediation": {
6+
"func": "Constant\/Issue",
7+
"constantCost": "1min"
8+
},
9+
"tags": [],
10+
"defaultSeverity": "Major",
11+
"ruleSpecification": "RSPEC-6978",
12+
"sqKey": "S6978",
13+
"scope": "All",
14+
"quickfix": "targeted",
15+
"code": {
16+
"impacts": {
17+
"RELIABILITY": "HIGH"
18+
},
19+
"attribute": "LOGICAL"
20+
}
21+
}

python-checks/src/main/resources/org/sonar/l10n/py/rules/python/Sonar_way_profile.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@
246246
"S6972",
247247
"S6973",
248248
"S6974",
249+
"S6978",
249250
"S6979",
250251
"S6982",
251252
"S6983",
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.sonar.python.checks.quickfix.PythonQuickFixVerifier;
24+
import org.sonar.python.checks.utils.PythonCheckVerifier;
25+
26+
class TorchModuleShouldCallInitCheckTest {
27+
@Test
28+
void test() {
29+
PythonCheckVerifier.verify("src/test/resources/checks/torchModuleShouldCallInit.py", new TorchModuleShouldCallInitCheck());
30+
}
31+
32+
@Test
33+
void testQuickFix() {
34+
PythonQuickFixVerifier.verify(new TorchModuleShouldCallInitCheck(),
35+
"""
36+
import torch
37+
class Test(torch.nn.Module):
38+
def __init__(self):
39+
some_method()
40+
""",
41+
"""
42+
import torch
43+
class Test(torch.nn.Module):
44+
def __init__(self):
45+
super().__init__()
46+
some_method()
47+
""");
48+
49+
PythonQuickFixVerifier.verify(new TorchModuleShouldCallInitCheck(),
50+
"""
51+
import torch
52+
class Test(torch.nn.Module):
53+
def __init__(self):
54+
...
55+
""",
56+
"""
57+
import torch
58+
class Test(torch.nn.Module):
59+
def __init__(self):
60+
super().__init__()
61+
...
62+
""");
63+
64+
PythonQuickFixVerifier.verifyNoQuickFixes(new TorchModuleShouldCallInitCheck(),
65+
"""
66+
import torch
67+
class Test(torch.nn.Module):
68+
def __init__(self): some_method()
69+
""");
70+
71+
PythonQuickFixVerifier.verifyNoQuickFixes(new TorchModuleShouldCallInitCheck(),
72+
"""
73+
import torch
74+
class Test(torch.nn.Module):
75+
def __init__(self): pass
76+
""");
77+
}
78+
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import torch.nn as nn
2+
3+
class NonCompliantModule(nn.Module):
4+
#^^^^^^^^^^^^^^^^^^> {{Inheritance happens here}}
5+
def __init__(self): #Noncompliant {{Add a call to super().__init__().}}
6+
#^^^^^^^^
7+
...
8+
9+
10+
class NonCompliantModule(OtherModule, nn.Module):
11+
def __init__(self): #Noncompliant
12+
...
13+
14+
class CompliantModule(NonExistantClass):
15+
def __init__(self, encoder, decoder):
16+
...
17+
18+
class CompliantModule(nn.Module):
19+
def __init__(self):
20+
super().__init__()
21+
class CompliantModule(nn.Module):
22+
pass
23+
24+
class CompliantModule(nn.Module):
25+
def __init__(self, cond):
26+
if cond:
27+
super().__init__()
28+
29+
class CompliantModule(nn.Module):
30+
def __init__(self, super):
31+
super().__init__()
32+
33+
class CompliantModule(nn.Module):
34+
def __init__(self):
35+
(lambda x: x)()
36+
(lambda x: x)().test()
37+
self.call_super()
38+
39+
def call_super(self):
40+
super().__init__()
41+
42+
43+
class CompliantModule(nn.Module):
44+
def __init__(self, fake_super):
45+
fake_super().__init__()
46+
super().not_init()
47+
48+
def call_super(self):
49+
super().__init__()
50+
51+
class CompliantModule2(nn.Module):
52+
def __init__(self): #FN
53+
do_something()
54+
55+
class Nested(Other):
56+
def __init__(self):
57+
super().__init__()
58+
59+
class UnrelatedCompliantModule:
60+
def __init__(self):
61+
...
62+
63+
def __init__(test):
64+
pass
65+
66+
def some_other_func():
67+
pass
68+
69+
class CompliantModule(nn.Module):
70+
if cond:
71+
class some_func: pass
72+
else:
73+
def some_func(): pass

0 commit comments

Comments
 (0)