2727import org .sonar .plugins .python .api .PythonSubscriptionCheck ;
2828import org .sonar .plugins .python .api .symbols .Symbol ;
2929import org .sonar .plugins .python .api .symbols .Usage ;
30- import org .sonar .plugins .python .api .tree .Argument ;
3130import org .sonar .plugins .python .api .tree .CallExpression ;
32- import org .sonar .plugins .python .api .tree .Expression ;
3331import org .sonar .plugins .python .api .tree .Name ;
3432import org .sonar .plugins .python .api .tree .QualifiedExpression ;
35- import org .sonar .plugins .python .api .tree .RegularArgument ;
3633import org .sonar .plugins .python .api .tree .Tree ;
3734import org .sonar .python .cfg .fixpoint .ReachingDefinitionsAnalysis ;
3835import org .sonar .python .tree .TreeUtils ;
3936
4037@ Rule (key = "S6982" )
4138public class TorchModuleModeShouldBeSetAfterLoadingCheck extends PythonSubscriptionCheck {
4239 private static final Set <String > STATE_SETTING_FUNCTION_FQNS = Set .of ("eval" , "train" );
43- private static final String TORCH_LOAD_FQN = "torch.load" ;
4440 private static final String LOAD_STATE_DICT_NAME = "load_state_dict" ;
4541 private static final String MESSAGE = "Set the module in training or evaluation mode." ;
46- private static final int IS_TORCH_LOAD_CALL_MAX_RECURSIVE_COUNTER = 10 ;
4742
4843 private ReachingDefinitionsAnalysis reachingDefinitionsAnalysis ;
4944
@@ -54,7 +49,7 @@ public void initialize(Context context) {
5449
5550 context .registerSyntaxNodeConsumer (Tree .Kind .CALL_EXPR , ctx -> {
5651 CallExpression callExpr = (CallExpression ) ctx .syntaxNode ();
57- List <Usage > receiverUsages = getForwardUsages (callExpr );
52+ List <Usage > receiverUsages = getForwardUsagesOfReceiver (callExpr );
5853 if (isLoadStateDictCall (callExpr ) && !hasEvalOrTrainUsage (receiverUsages ) && !isModelPassedOn (receiverUsages )) {
5954 ctx .addIssue (callExpr .callee (), MESSAGE );
6055 }
@@ -65,33 +60,14 @@ private boolean isLoadStateDictCall(CallExpression callExpr) {
6560 // To properly check if the correct load_state_dict is called, typeshed type information would be required.
6661 // Since this is currently not possible, we check if the parameter to load_state_dict is torch.load(...),
6762 // with the assumption that if torch.load is passed to this load_state_dict, it is probably the correct method
68- if (callExpr .callee () instanceof QualifiedExpression qualifiedExpr ) {
69- return LOAD_STATE_DICT_NAME .equals (qualifiedExpr .name ().name ()) && containsTorchLoadCall (callExpr .arguments ());
63+ if (callExpr .callee () instanceof QualifiedExpression qualifiedExpr ) {
64+ return qualifiedExpr .qualifier ().type ().mustBeOrExtend ("torch.nn.modules.module.Module" )
65+ && LOAD_STATE_DICT_NAME .equals (qualifiedExpr .name ().name ());
7066 }
7167 return false ;
7268 }
7369
74- private boolean containsTorchLoadCall (List <Argument > args ) {
75- return args .stream ()
76- .flatMap (TreeUtils .toStreamInstanceOfMapper (RegularArgument .class ))
77- .anyMatch (arg -> isTorchLoadCall (arg .expression (), 0 ));
78- }
79-
80- private boolean isTorchLoadCall (Expression expr , int recursiveCounter ) {
81- if (recursiveCounter > IS_TORCH_LOAD_CALL_MAX_RECURSIVE_COUNTER ) {
82- return false ;
83- } else if (expr instanceof CallExpression callExpr ) {
84- Symbol calleeSymbol = callExpr .calleeSymbol ();
85- return calleeSymbol != null && TORCH_LOAD_FQN .equals (calleeSymbol .fullyQualifiedName ());
86- } else if (expr instanceof Name name ) {
87- return reachingDefinitionsAnalysis .valuesAtLocation (name ).stream ()
88- .anyMatch (definitionExpr -> isTorchLoadCall (definitionExpr , recursiveCounter + 1 ));
89- } else {
90- return false ;
91- }
92- }
93-
94- private static List <Usage > getForwardUsages (CallExpression callExpr ) {
70+ private static List <Usage > getForwardUsagesOfReceiver (CallExpression callExpr ) {
9571 List <Usage > usages = getFunctionCallReceiverName (callExpr )
9672 .flatMap (name -> Optional .ofNullable (name .symbol ()))
9773 .map (Symbol ::usages )
0 commit comments