27
27
import org .sonar .plugins .python .api .PythonSubscriptionCheck ;
28
28
import org .sonar .plugins .python .api .symbols .Symbol ;
29
29
import org .sonar .plugins .python .api .symbols .Usage ;
30
- import org .sonar .plugins .python .api .tree .Argument ;
31
30
import org .sonar .plugins .python .api .tree .CallExpression ;
32
- import org .sonar .plugins .python .api .tree .Expression ;
33
31
import org .sonar .plugins .python .api .tree .Name ;
34
32
import org .sonar .plugins .python .api .tree .QualifiedExpression ;
35
- import org .sonar .plugins .python .api .tree .RegularArgument ;
36
33
import org .sonar .plugins .python .api .tree .Tree ;
37
34
import org .sonar .python .cfg .fixpoint .ReachingDefinitionsAnalysis ;
38
35
import org .sonar .python .tree .TreeUtils ;
39
36
40
37
@ Rule (key = "S6982" )
41
38
public class TorchModuleModeShouldBeSetAfterLoadingCheck extends PythonSubscriptionCheck {
42
39
private static final Set <String > STATE_SETTING_FUNCTION_FQNS = Set .of ("eval" , "train" );
43
- private static final String TORCH_LOAD_FQN = "torch.load" ;
44
40
private static final String LOAD_STATE_DICT_NAME = "load_state_dict" ;
45
41
private static final String MESSAGE = "Set the module in training or evaluation mode." ;
46
- private static final int IS_TORCH_LOAD_CALL_MAX_RECURSIVE_COUNTER = 10 ;
47
42
48
43
private ReachingDefinitionsAnalysis reachingDefinitionsAnalysis ;
49
44
@@ -54,7 +49,7 @@ public void initialize(Context context) {
54
49
55
50
context .registerSyntaxNodeConsumer (Tree .Kind .CALL_EXPR , ctx -> {
56
51
CallExpression callExpr = (CallExpression ) ctx .syntaxNode ();
57
- List <Usage > receiverUsages = getForwardUsages (callExpr );
52
+ List <Usage > receiverUsages = getForwardUsagesOfReceiver (callExpr );
58
53
if (isLoadStateDictCall (callExpr ) && !hasEvalOrTrainUsage (receiverUsages ) && !isModelPassedOn (receiverUsages )) {
59
54
ctx .addIssue (callExpr .callee (), MESSAGE );
60
55
}
@@ -65,33 +60,14 @@ private boolean isLoadStateDictCall(CallExpression callExpr) {
65
60
// To properly check if the correct load_state_dict is called, typeshed type information would be required.
66
61
// Since this is currently not possible, we check if the parameter to load_state_dict is torch.load(...),
67
62
// with the assumption that if torch.load is passed to this load_state_dict, it is probably the correct method
68
- if (callExpr .callee () instanceof QualifiedExpression qualifiedExpr ) {
69
- return LOAD_STATE_DICT_NAME .equals (qualifiedExpr .name ().name ()) && containsTorchLoadCall (callExpr .arguments ());
63
+ if (callExpr .callee () instanceof QualifiedExpression qualifiedExpr ) {
64
+ return qualifiedExpr .qualifier ().type ().mustBeOrExtend ("torch.nn.modules.module.Module" )
65
+ && LOAD_STATE_DICT_NAME .equals (qualifiedExpr .name ().name ());
70
66
}
71
67
return false ;
72
68
}
73
69
74
- private boolean containsTorchLoadCall (List <Argument > args ) {
75
- return args .stream ()
76
- .flatMap (TreeUtils .toStreamInstanceOfMapper (RegularArgument .class ))
77
- .anyMatch (arg -> isTorchLoadCall (arg .expression (), 0 ));
78
- }
79
-
80
- private boolean isTorchLoadCall (Expression expr , int recursiveCounter ) {
81
- if (recursiveCounter > IS_TORCH_LOAD_CALL_MAX_RECURSIVE_COUNTER ) {
82
- return false ;
83
- } else if (expr instanceof CallExpression callExpr ) {
84
- Symbol calleeSymbol = callExpr .calleeSymbol ();
85
- return calleeSymbol != null && TORCH_LOAD_FQN .equals (calleeSymbol .fullyQualifiedName ());
86
- } else if (expr instanceof Name name ) {
87
- return reachingDefinitionsAnalysis .valuesAtLocation (name ).stream ()
88
- .anyMatch (definitionExpr -> isTorchLoadCall (definitionExpr , recursiveCounter + 1 ));
89
- } else {
90
- return false ;
91
- }
92
- }
93
-
94
- private static List <Usage > getForwardUsages (CallExpression callExpr ) {
70
+ private static List <Usage > getForwardUsagesOfReceiver (CallExpression callExpr ) {
95
71
List <Usage > usages = getFunctionCallReceiverName (callExpr )
96
72
.flatMap (name -> Optional .ofNullable (name .symbol ()))
97
73
.map (Symbol ::usages )
0 commit comments