Commit 312f89c

SONARPY-2150: Fix FP on S6982 when method is used on optimizers (#1984)
* SONARPY-2150 check type of module
* SONARPY-2150 improve coverage
1 parent fae4aa6 commit 312f89c
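
The false positive addressed here occurs when load_state_dict is called on an object that is not a torch.nn.Module, typically an optimizer. A minimal sketch of the scenario, with illustrative variable names and checkpoint paths (not taken from the commit):

import torch
import torchvision.models as models

model = models.vgg16()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Previously S6982 flagged any load_state_dict call whose argument came from
# torch.load(...), so this optimizer call was reported even though optimizers
# have no train()/eval() mode. The check now verifies that the receiver's type
# extends torch.nn.Module, so the call below is no longer reported.
optimizer.load_state_dict(torch.load('optimizer_state.pth'))

# Still a candidate for an issue: model is a torch.nn.Module and neither
# train() nor eval() is called after loading.
model.load_state_dict(torch.load('model_weights.pth'))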

2 files changed: 21 additions & 53 deletions

python-checks/src/main/java/org/sonar/python/checks/TorchModuleModeShouldBeSetAfterLoadingCheck.java

Lines changed: 5 additions & 29 deletions
@@ -27,23 +27,18 @@
 import org.sonar.plugins.python.api.PythonSubscriptionCheck;
 import org.sonar.plugins.python.api.symbols.Symbol;
 import org.sonar.plugins.python.api.symbols.Usage;
-import org.sonar.plugins.python.api.tree.Argument;
 import org.sonar.plugins.python.api.tree.CallExpression;
-import org.sonar.plugins.python.api.tree.Expression;
 import org.sonar.plugins.python.api.tree.Name;
 import org.sonar.plugins.python.api.tree.QualifiedExpression;
-import org.sonar.plugins.python.api.tree.RegularArgument;
 import org.sonar.plugins.python.api.tree.Tree;
 import org.sonar.python.cfg.fixpoint.ReachingDefinitionsAnalysis;
 import org.sonar.python.tree.TreeUtils;
 
 @Rule(key = "S6982")
 public class TorchModuleModeShouldBeSetAfterLoadingCheck extends PythonSubscriptionCheck {
   private static final Set<String> STATE_SETTING_FUNCTION_FQNS = Set.of("eval", "train");
-  private static final String TORCH_LOAD_FQN = "torch.load";
   private static final String LOAD_STATE_DICT_NAME = "load_state_dict";
   private static final String MESSAGE = "Set the module in training or evaluation mode.";
-  private static final int IS_TORCH_LOAD_CALL_MAX_RECURSIVE_COUNTER = 10;
 
   private ReachingDefinitionsAnalysis reachingDefinitionsAnalysis;
 
@@ -54,7 +49,7 @@ public void initialize(Context context) {
 
     context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, ctx -> {
       CallExpression callExpr = (CallExpression) ctx.syntaxNode();
-      List<Usage> receiverUsages = getForwardUsages(callExpr);
+      List<Usage> receiverUsages = getForwardUsagesOfReceiver(callExpr);
       if (isLoadStateDictCall(callExpr) && !hasEvalOrTrainUsage(receiverUsages) && !isModelPassedOn(receiverUsages)) {
         ctx.addIssue(callExpr.callee(), MESSAGE);
       }
@@ -65,33 +60,14 @@ private boolean isLoadStateDictCall(CallExpression callExpr) {
     // To properly check if the correct load_state_dict is called, typeshed type information would be required.
     // Since this is currently not possible, we check if the parameter to load_state_dict is torch.load(...),
     // with the assumption that if torch.load is passed to this load_state_dict, it is probably the correct method
-    if(callExpr.callee() instanceof QualifiedExpression qualifiedExpr) {
-      return LOAD_STATE_DICT_NAME.equals(qualifiedExpr.name().name()) && containsTorchLoadCall(callExpr.arguments());
+    if (callExpr.callee() instanceof QualifiedExpression qualifiedExpr) {
+      return qualifiedExpr.qualifier().type().mustBeOrExtend("torch.nn.modules.module.Module")
+        && LOAD_STATE_DICT_NAME.equals(qualifiedExpr.name().name());
     }
     return false;
   }
 
-  private boolean containsTorchLoadCall(List<Argument> args) {
-    return args.stream()
-      .flatMap(TreeUtils.toStreamInstanceOfMapper(RegularArgument.class))
-      .anyMatch(arg -> isTorchLoadCall(arg.expression(), 0));
-  }
-
-  private boolean isTorchLoadCall(Expression expr, int recursiveCounter) {
-    if (recursiveCounter > IS_TORCH_LOAD_CALL_MAX_RECURSIVE_COUNTER) {
-      return false;
-    } else if (expr instanceof CallExpression callExpr) {
-      Symbol calleeSymbol = callExpr.calleeSymbol();
-      return calleeSymbol != null && TORCH_LOAD_FQN.equals(calleeSymbol.fullyQualifiedName());
-    } else if (expr instanceof Name name) {
-      return reachingDefinitionsAnalysis.valuesAtLocation(name).stream()
-        .anyMatch(definitionExpr -> isTorchLoadCall(definitionExpr, recursiveCounter + 1));
-    } else {
-      return false;
-    }
-  }
-
-  private static List<Usage> getForwardUsages(CallExpression callExpr) {
+  private static List<Usage> getForwardUsagesOfReceiver(CallExpression callExpr) {
     List<Usage> usages = getFunctionCallReceiverName(callExpr)
       .flatMap(name -> Optional.ofNullable(name.symbol()))
      .map(Symbol::usages)
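
With this change, isLoadStateDictCall only considers calls whose receiver type must be or extend torch.nn.modules.module.Module, so the recursive tracing of torch.load arguments is no longer needed. For context, a minimal compliant sketch of the pattern the rule expects on a Module receiver (the weights file name is illustrative):

import torch
import torchvision.models as models

model = models.vgg16()                                   # resolves to a torch.nn.Module subclass
model.load_state_dict(torch.load('model_weights.pth'))   # loading alone would be reported
model.eval()                                             # setting evaluation (or training) mode satisfies S6982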
Lines changed: 16 additions & 24 deletions
@@ -1,4 +1,4 @@
-import torch
+import torch.nn
 import torchvision.models as models
 
 def noncompliant():
@@ -7,47 +7,39 @@ def noncompliant():
     #^^^^^^^^^^^^^^^^^^^^^
     ...
 
-def noncompliant(model):
-    model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
+class CustomModule(torch.nn.Module):
+    pass
 
 def noncompliant():
+    model = CustomModule()
     model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
+    ...
 
 def noncompliant():
-    get_model().load_state_dict(torch.load('model_weights.pth')) # Noncompliant
-
-def noncompliant(model):
-    weights = torch.load('model_weights.pth')
-    weights2 = weights
-    model.load_state_dict(weights2) # Noncompliant
+    model = models.vgg16()
+    model.train()
+    model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
 
 def noncompliant():
     model = models.vgg16()
+    model.load_state_dict(torch.load('model_weights.pth'))
     model.train()
+
     model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant
+
     other_model = model
 
 def compliant(model):
-    weights = weights
-    model.load_state_dict(weights)
+    model.load_state_dict(torch.load('model_weights.pth'))
 
-def compliant():
-    model1 = models.vgg16()
-    model1.load_state_dict(torch.load('model_weights.pth'))
-    model1.eval()
 
-def compliant():
-    model2 = models.vgg16()
-    model2.load_state_dict(torch.load('model_weights.pth'))
-    other_model = model2
-    model2.train()
 
 def compliant():
     model3 = models.vgg16()
     model3.load_state_dict(torch.load('model_weights.pth')) # Ok if model is passed as argument to a function do not raise at all train or eval could be called in such functions
     foo(model3)
 
-def compliant():
-    # Ok since no torch.load() result is passed as an argument
-    model.load_state_dict(1 + 1)
-    model.load_state_dict((lambda x: x)())
+
+
+
+