Skip to content

Commit 93ffcfc

Browse files
authored
SONARPY-2151 S6982: Fix fn when import torch is used instead of the full import name torch.nn
1 parent fb1a65c commit 93ffcfc

File tree

2 files changed

+43
-31
lines changed

2 files changed

+43
-31
lines changed

python-checks/src/main/java/org/sonar/python/checks/TorchModuleModeShouldBeSetAfterLoadingCheck.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.sonar.plugins.python.api.tree.Name;
3232
import org.sonar.plugins.python.api.tree.QualifiedExpression;
3333
import org.sonar.plugins.python.api.tree.Tree;
34+
import org.sonar.plugins.python.api.types.InferredType;
3435
import org.sonar.python.tree.TreeUtils;
3536

3637
@Rule(key = "S6982")
@@ -56,8 +57,10 @@ private static boolean isLoadStateDictCall(CallExpression callExpr) {
5657
// Since this is currently not possible, we check if the parameter to load_state_dict is torch.load(...),
5758
// with the assumption that if torch.load is passed to this load_state_dict, it is probably the correct method
5859
if (callExpr.callee() instanceof QualifiedExpression qualifiedExpr) {
59-
return qualifiedExpr.qualifier().type().mustBeOrExtend("torch.nn.modules.module.Module")
60-
&& LOAD_STATE_DICT_NAME.equals(qualifiedExpr.name().name());
60+
InferredType qualifierType = qualifiedExpr.qualifier().type();
61+
boolean isModule = qualifierType.mustBeOrExtend("torch.nn.modules.module.Module")
62+
|| qualifierType.mustBeOrExtend("torch.nn.Module");
63+
return isModule && LOAD_STATE_DICT_NAME.equals(qualifiedExpr.name().name());
6164
}
6265
return false;
6366
}
Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,54 @@
1-
import torch.nn
2-
import torchvision.models as models
31

4-
def noncompliant():
5-
model = models.vgg16()
6-
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant {{Set the module in training or evaluation mode.}}
7-
#^^^^^^^^^^^^^^^^^^^^^
8-
...
2+
import torchvision.models as models
93

10-
class CustomModule(torch.nn.Module):
11-
pass
4+
def torch_nn_imported():
5+
import torch.nn
6+
def noncompliant():
7+
model = models.vgg16()
8+
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant {{Set the module in training or evaluation mode.}}
9+
#^^^^^^^^^^^^^^^^^^^^^
10+
...
1211

13-
def noncompliant():
14-
model = CustomModule()
15-
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
16-
...
12+
class CustomModule(torch.nn.Module):
13+
pass
1714

18-
def noncompliant():
19-
model = models.vgg16()
20-
model.train()
21-
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
15+
def noncompliant():
16+
model = CustomModule()
17+
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
18+
...
2219

23-
def noncompliant():
24-
model = models.vgg16()
25-
model.load_state_dict(torch.load('model_weights.pth'))
26-
model.train()
20+
def noncompliant():
21+
model = models.vgg16()
22+
model.train()
23+
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
2724

28-
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
25+
def noncompliant():
26+
model = models.vgg16()
27+
model.load_state_dict(torch.load('model_weights.pth'))
28+
model.train()
2929

30-
other_model = model
30+
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
3131

32-
def compliant(model):
33-
model.load_state_dict(torch.load('model_weights.pth'))
32+
other_model = model
3433

34+
def compliant(model):
35+
model.load_state_dict(torch.load('model_weights.pth'))
3536

3637

37-
def compliant():
38-
model3 = models.vgg16()
39-
model3.load_state_dict(torch.load('model_weights.pth')) # Ok if model is passed as argument to a function do not raise at all train or eval could be called in such functions
40-
foo(model3)
4138

39+
def compliant():
40+
model3 = models.vgg16()
41+
model3.load_state_dict(torch.load('model_weights.pth')) # Ok if model is passed as argument to a function do not raise at all train or eval could be called in such functions
42+
foo(model3)
4243

4344

45+
def torch_imported():
46+
import torch
4447

48+
class CustomModule(torch.nn.Module):
49+
pass
4550

51+
def noncompliant():
52+
model = CustomModule()
53+
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
54+
...

0 commit comments

Comments
 (0)