Skip to content

Commit 0068790

Browse files
authored
Rule S6982: module mode should be set after load_state_dict (#1976)
* SONARPY-1910 add metadata * SONARPY-1910 Rule S6982: module mode should be set after load_state_dict * SONARPY-1910 add license header and update Checklist * SONARPY-1910 fix SQ issues * SONARPY-1910 address pr comments
1 parent f92897a commit 0068790

File tree

7 files changed

+282
-0
lines changed

7 files changed

+282
-0
lines changed

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ public static Iterable<Class> getChecks() {
368368
TooManyReturnsCheck.class,
369369
TorchAutogradVariableShouldNotBeUsedCheck.class,
370370
TorchLoadLeadsToUntrustedCodeExecutionCheck.class,
371+
TorchModuleModeShouldBeSetAfterLoadingCheck.class,
371372
TrailingCommentCheck.class,
372373
TrailingWhitespaceCheck.class,
373374
TypeAliasAnnotationCheck.class,
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.Collections;
23+
import java.util.List;
24+
import java.util.Optional;
25+
import java.util.Set;
26+
import org.sonar.check.Rule;
27+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
28+
import org.sonar.plugins.python.api.symbols.Symbol;
29+
import org.sonar.plugins.python.api.symbols.Usage;
30+
import org.sonar.plugins.python.api.tree.Argument;
31+
import org.sonar.plugins.python.api.tree.CallExpression;
32+
import org.sonar.plugins.python.api.tree.Expression;
33+
import org.sonar.plugins.python.api.tree.Name;
34+
import org.sonar.plugins.python.api.tree.QualifiedExpression;
35+
import org.sonar.plugins.python.api.tree.RegularArgument;
36+
import org.sonar.plugins.python.api.tree.Tree;
37+
import org.sonar.python.cfg.fixpoint.ReachingDefinitionsAnalysis;
38+
import org.sonar.python.tree.TreeUtils;
39+
40+
@Rule(key = "S6982")
41+
public class TorchModuleModeShouldBeSetAfterLoadingCheck extends PythonSubscriptionCheck {
42+
private static final Set<String> STATE_SETTING_FUNCTION_FQNS = Set.of("eval", "train");
43+
private static final String TORCH_LOAD_FQN = "torch.load";
44+
private static final String LOAD_STATE_DICT_NAME = "load_state_dict";
45+
private static final String MESSAGE = "Set the module in training or evaluation mode.";
46+
private static final int IS_TORCH_LOAD_CALL_MAX_RECURSIVE_COUNTER = 10;
47+
48+
private ReachingDefinitionsAnalysis reachingDefinitionsAnalysis;
49+
50+
@Override
51+
public void initialize(Context context) {
52+
context.registerSyntaxNodeConsumer(Tree.Kind.FILE_INPUT, ctx -> reachingDefinitionsAnalysis =
53+
new ReachingDefinitionsAnalysis(ctx.pythonFile()));
54+
55+
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, ctx -> {
56+
CallExpression callExpr = (CallExpression) ctx.syntaxNode();
57+
List<Usage> receiverUsages = getForwardUsages(callExpr);
58+
if (isLoadStateDictCall(callExpr) && !hasEvalOrTrainUsage(receiverUsages) && !isModelPassedOn(receiverUsages)) {
59+
ctx.addIssue(callExpr.callee(), MESSAGE);
60+
}
61+
});
62+
}
63+
64+
private boolean isLoadStateDictCall(CallExpression callExpr) {
65+
// To properly check if the correct load_state_dict is called, typeshed type information would be required.
66+
// Since this is currently not possible, we check if the parameter to load_state_dict is torch.load(...),
67+
// with the assumption that if torch.load is passed to this load_state_dict, it is probably the correct method
68+
if(callExpr.callee() instanceof QualifiedExpression qualifiedExpr) {
69+
return LOAD_STATE_DICT_NAME.equals(qualifiedExpr.name().name()) && containsTorchLoadCall(callExpr.arguments());
70+
}
71+
return false;
72+
}
73+
74+
private boolean containsTorchLoadCall(List<Argument> args) {
75+
return args.stream()
76+
.flatMap(TreeUtils.toStreamInstanceOfMapper(RegularArgument.class))
77+
.anyMatch(arg -> isTorchLoadCall(arg.expression(), 0));
78+
}
79+
80+
private boolean isTorchLoadCall(Expression expr, int recursiveCounter) {
81+
if (recursiveCounter > IS_TORCH_LOAD_CALL_MAX_RECURSIVE_COUNTER) {
82+
return false;
83+
} else if (expr instanceof CallExpression callExpr) {
84+
Symbol calleeSymbol = callExpr.calleeSymbol();
85+
return calleeSymbol != null && TORCH_LOAD_FQN.equals(calleeSymbol.fullyQualifiedName());
86+
} else if (expr instanceof Name name) {
87+
return reachingDefinitionsAnalysis.valuesAtLocation(name).stream()
88+
.anyMatch(definitionExpr -> isTorchLoadCall(definitionExpr, recursiveCounter + 1));
89+
} else {
90+
return false;
91+
}
92+
}
93+
94+
private static List<Usage> getForwardUsages(CallExpression callExpr) {
95+
List<Usage> usages = getFunctionCallReceiverName(callExpr)
96+
.flatMap(name -> Optional.ofNullable(name.symbol()))
97+
.map(Symbol::usages)
98+
.orElse(Collections.emptyList());
99+
100+
return usages.stream()
101+
.filter(usage -> usage.tree().firstToken().line() > callExpr.firstToken().line())
102+
.toList();
103+
}
104+
105+
private static Optional<Name> getFunctionCallReceiverName(CallExpression callExpr) {
106+
return Optional.ofNullable(callExpr.callee())
107+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(QualifiedExpression.class))
108+
.flatMap(qualifiedExpr -> Optional.ofNullable(qualifiedExpr.qualifier()))
109+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(Name.class));
110+
}
111+
112+
private static boolean hasEvalOrTrainUsage(List<Usage> usages) {
113+
return usages.stream().anyMatch(TorchModuleModeShouldBeSetAfterLoadingCheck::isEvalOrTrain);
114+
}
115+
116+
private static boolean isEvalOrTrain(Usage usage) {
117+
Tree callTree = TreeUtils.firstAncestorOfKind(usage.tree(), Tree.Kind.CALL_EXPR);
118+
if (callTree != null) {
119+
CallExpression usageCall = (CallExpression) callTree;
120+
Symbol usageCallSymbol = usageCall.calleeSymbol();
121+
return usageCallSymbol != null && STATE_SETTING_FUNCTION_FQNS.contains(usageCallSymbol.name());
122+
}
123+
return false;
124+
}
125+
126+
private static boolean isModelPassedOn(List<Usage> usages) {
127+
return usages.stream().anyMatch(TorchModuleModeShouldBeSetAfterLoadingCheck::isPassingModel);
128+
}
129+
130+
private static boolean isPassingModel(Usage usage) {
131+
return TreeUtils.firstAncestorOfKind(usage.tree(), Tree.Kind.CALL_EXPR) != null;
132+
}
133+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
<p>This rule raises an issue when a PyTorch model state is loaded and <code>torch.nn.Module.eval()</code> or <code>torch.nn.Module.train()</code> is
2+
not called.</p>
3+
<h2>Why is this an issue?</h2>
4+
<p>When using PyTorch it is common practice to load and save a model’s state from/to a <code>.pth</code> file. Doing so allows, for example, to
5+
instantiate an untrained model and load learned parameters coming from another pre-trained model. Once the learned parameters are loaded to the model
6+
it is important, before inferencing, to clearly state the intention by calling <code>torch.nn.Module.eval()</code> method to set the model in
7+
evaluation mode or calling <code>torch.nn.Module.train()</code> to indicate the training will resume. Failing to call
8+
<code>torch.nn.Module.eval()</code> would leave the model in training mode which may not be the intention.</p>
9+
<h2>How to fix it</h2>
10+
<p>Call the <code>torch.nn.Module.eval()</code> or <code>torch.nn.Module.train()</code> method on the model.</p>
11+
<h3>Code examples</h3>
12+
<h4>Noncompliant code example</h4>
13+
<pre data-diff-id="1" data-diff-type="noncompliant">
14+
import torch
15+
import torchvision.models as models
16+
17+
model = models.vgg16()
18+
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant: model.train() or model.eval() was not called.
19+
</pre>
20+
<h4>Compliant solution</h4>
21+
<pre data-diff-id="1" data-diff-type="compliant">
22+
import torch
23+
import torchvision.models as models
24+
25+
model = models.vgg16()
26+
model.load_state_dict(torch.load('model_weights.pth'))
27+
model.eval()
28+
</pre>
29+
<h2>Resources</h2>
30+
<h3>Documentation</h3>
31+
<ul>
32+
<li> PyTorch Documentation - <a href="https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.eval">eval - reference</a>
33+
</li>
34+
<li> PyTorch Documentation - <a href="https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.train">train - reference</a>
35+
</li>
36+
<li> PyTorch Documentation - <a href="https://pytorch.org/docs/stable/notes/autograd.html#evaluation-mode-nn-module-eval">Autograd - Evaluation
37+
Mode</a> </li>
38+
</ul>
39+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"title": "\"model.eval()\" or \"model.train()\" should be called after loading the state of a PyTorch model",
3+
"type": "CODE_SMELL",
4+
"status": "ready",
5+
"remediation": {
6+
"func": "Constant\/Issue",
7+
"constantCost": "1min"
8+
},
9+
"tags": [
10+
"pytorch",
11+
"machine-learning"
12+
],
13+
"defaultSeverity": "Major",
14+
"ruleSpecification": "RSPEC-6982",
15+
"sqKey": "S6982",
16+
"scope": "All",
17+
"quickfix": "infeasible",
18+
"code": {
19+
"impacts": {
20+
"MAINTAINABILITY": "LOW",
21+
"RELIABILITY": "MEDIUM"
22+
},
23+
"attribute": "CLEAR"
24+
}
25+
}

python-checks/src/main/resources/org/sonar/l10n/py/rules/python/Sonar_way_profile.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@
247247
"S6973",
248248
"S6974",
249249
"S6979",
250+
"S6982",
250251
"S6983",
251252
"S6984",
252253
"S6985"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.sonar.python.checks.utils.PythonCheckVerifier;
24+
25+
class TorchModuleModeShouldBeSetAfterLoadingCheckTest {
26+
@Test
27+
void test() {
28+
PythonCheckVerifier.verify("src/test/resources/checks/torchModuleModeShouldBeSetAfterLoadingCheck.py", new TorchModuleModeShouldBeSetAfterLoadingCheck());
29+
}
30+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import torch
2+
import torchvision.models as models
3+
4+
def noncompliant():
5+
model = models.vgg16()
6+
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant {{Set the module in training or evaluation mode.}}
7+
#^^^^^^^^^^^^^^^^^^^^^
8+
...
9+
10+
def noncompliant(model):
11+
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
12+
13+
def noncompliant():
14+
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
15+
16+
def noncompliant():
17+
get_model().load_state_dict(torch.load('model_weights.pth')) # Noncompliant
18+
19+
def noncompliant(model):
20+
weights = torch.load('model_weights.pth')
21+
weights2 = weights
22+
model.load_state_dict(weights2) # Noncompliant
23+
24+
def noncompliant():
25+
model = models.vgg16()
26+
model.train()
27+
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
28+
other_model = model
29+
30+
def compliant(model):
31+
weights = weights
32+
model.load_state_dict(weights)
33+
34+
def compliant():
35+
model1 = models.vgg16()
36+
model1.load_state_dict(torch.load('model_weights.pth'))
37+
model1.eval()
38+
39+
def compliant():
40+
model2 = models.vgg16()
41+
model2.load_state_dict(torch.load('model_weights.pth'))
42+
other_model = model2
43+
model2.train()
44+
45+
def compliant():
46+
model3 = models.vgg16()
47+
model3.load_state_dict(torch.load('model_weights.pth')) # Ok if model is passed as argument to a function do not raise at all train or eval could be called in such functions
48+
foo(model3)
49+
50+
def compliant():
51+
# Ok since no torch.load() result is passed as an argument
52+
model.load_state_dict(1 + 1)
53+
model.load_state_dict((lambda x: x)())

0 commit comments

Comments
 (0)