@@ -49,6 +49,9 @@ public void initialize(Context context) {
4949 context .registerSyntaxNodeConsumer (Tree .Kind .CALL_EXPR , this ::checkModelCall );
5050
5151 context .registerSyntaxNodeConsumer (Tree .Kind .FUNCDEF , this ::initializeContext );
52+
53+ evalModelsInContext .put (null , new HashSet <>());
54+ noGradScopesInContext .put (null , new HashSet <>());
5255 }
5356
5457 private void initializeContext (SubscriptionContext context ) {
@@ -73,12 +76,14 @@ private void checkEvalCall(SubscriptionContext context) {
7376
7477
7578 if (expr .name ().name ().equals ("eval" )) {
76- String modelName = expr .qualifier ().firstToken ().value ();
77-
78-
79- Tree enclosingContext = getEnclosingContext (expr );
80-
81- evalModelsInContext .computeIfAbsent (enclosingContext , k -> new HashSet <>()).add (modelName );
79+ if (expr .qualifier () != null && expr .qualifier ().firstToken () != null ) {
80+ String modelName = expr .qualifier ().firstToken ().value ();
81+
82+
83+ Tree enclosingContext = getEnclosingContext (expr );
84+
85+ evalModelsInContext .computeIfAbsent (enclosingContext , k -> new HashSet <>()).add (modelName );
86+ }
8287 }
8388 }
8489
@@ -103,13 +108,26 @@ private boolean isNoGradCall(Expression expr) {
103108 CallExpression callExpr = (CallExpression ) expr ;
104109 if (callExpr .callee ().is (Tree .Kind .QUALIFIED_EXPR )) {
105110 QualifiedExpression qualExpr = (QualifiedExpression ) callExpr .callee ();
106- return qualExpr .qualifier ().firstToken ().value ().equals ("torch" ) &&
111+ return qualExpr .qualifier () != null &&
112+ qualExpr .qualifier ().firstToken () != null &&
113+ qualExpr .qualifier ().firstToken ().value ().equals ("torch" ) &&
107114 qualExpr .name ().name ().equals ("no_grad" );
108115 }
109116 }
110117 return false ;
111118 }
112119
120+ /**
121+ * Checks if a model call is made in evaluation mode without the `torch.no_grad()` context.
122+ * <p>
123+ * This method identifies calls to models and verifies if they are in evaluation mode
124+ * (tracked by `evalModelsInContext`). If a model is in evaluation mode and the call is not
125+ * within a `torch.no_grad()` context, an issue is reported.
126+ * </p>
127+ *
128+ * @param context The subscription context containing the syntax node for the model call.
129+ * This is used to extract the call expression and its enclosing context.
130+ */
113131 private void checkModelCall (SubscriptionContext context ) {
114132 CallExpression callExpr = (CallExpression ) context .syntaxNode ();
115133
@@ -119,7 +137,12 @@ private void checkModelCall(SubscriptionContext context) {
119137 Expression callee = callExpr .callee ();
120138 String modelName = null ;
121139
122- if (!callee .is (Tree .Kind .QUALIFIED_EXPR )) {
140+ if (callee .is (Tree .Kind .QUALIFIED_EXPR )) {
141+ QualifiedExpression qualExpr = (QualifiedExpression ) callee ;
142+ if (qualExpr .qualifier () != null && qualExpr .qualifier ().firstToken () != null ) {
143+ modelName = qualExpr .qualifier ().firstToken ().value ();
144+ }
145+ } else {
123146 modelName = callee .firstToken ().value ();
124147 }
125148
0 commit comments