Skip to content

Commit 33a25e4

Browse files
committed
Copilot suggestions
1 parent 5d9babe commit 33a25e4

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

src/main/java/org/greencodeinitiative/creedengo/python/checks/DisableGradientForModelEval.java

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ public void initialize(Context context) {
4949
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, this::checkModelCall);
5050

5151
context.registerSyntaxNodeConsumer(Tree.Kind.FUNCDEF, this::initializeContext);
52+
53+
evalModelsInContext.put(null, new HashSet<>());
54+
noGradScopesInContext.put(null, new HashSet<>());
5255
}
5356

5457
private void initializeContext(SubscriptionContext context) {
@@ -73,12 +76,14 @@ private void checkEvalCall(SubscriptionContext context) {
7376

7477

7578
if (expr.name().name().equals("eval")) {
76-
String modelName = expr.qualifier().firstToken().value();
77-
78-
79-
Tree enclosingContext = getEnclosingContext(expr);
80-
81-
evalModelsInContext.computeIfAbsent(enclosingContext, k -> new HashSet<>()).add(modelName);
79+
if (expr.qualifier() != null && expr.qualifier().firstToken() != null) {
80+
String modelName = expr.qualifier().firstToken().value();
81+
82+
83+
Tree enclosingContext = getEnclosingContext(expr);
84+
85+
evalModelsInContext.computeIfAbsent(enclosingContext, k -> new HashSet<>()).add(modelName);
86+
}
8287
}
8388
}
8489

@@ -103,13 +108,26 @@ private boolean isNoGradCall(Expression expr) {
103108
CallExpression callExpr = (CallExpression) expr;
104109
if (callExpr.callee().is(Tree.Kind.QUALIFIED_EXPR)) {
105110
QualifiedExpression qualExpr = (QualifiedExpression) callExpr.callee();
106-
return qualExpr.qualifier().firstToken().value().equals("torch") &&
111+
return qualExpr.qualifier() != null &&
112+
qualExpr.qualifier().firstToken() != null &&
113+
qualExpr.qualifier().firstToken().value().equals("torch") &&
107114
qualExpr.name().name().equals("no_grad");
108115
}
109116
}
110117
return false;
111118
}
112119

120+
/**
121+
* Checks if a model call is made in evaluation mode without the `torch.no_grad()` context.
122+
* <p>
123+
* This method identifies calls to models and verifies if they are in evaluation mode
124+
* (tracked by `evalModelsInContext`). If a model is in evaluation mode and the call is not
125+
* within a `torch.no_grad()` context, an issue is reported.
126+
* </p>
127+
*
128+
* @param context The subscription context containing the syntax node for the model call.
129+
* This is used to extract the call expression and its enclosing context.
130+
*/
113131
private void checkModelCall(SubscriptionContext context) {
114132
CallExpression callExpr = (CallExpression) context.syntaxNode();
115133

@@ -119,7 +137,12 @@ private void checkModelCall(SubscriptionContext context) {
119137
Expression callee = callExpr.callee();
120138
String modelName = null;
121139

122-
if (!callee.is(Tree.Kind.QUALIFIED_EXPR)) {
140+
if (callee.is(Tree.Kind.QUALIFIED_EXPR)) {
141+
QualifiedExpression qualExpr = (QualifiedExpression) callee;
142+
if (qualExpr.qualifier() != null && qualExpr.qualifier().firstToken() != null) {
143+
modelName = qualExpr.qualifier().firstToken().value();
144+
}
145+
} else {
123146
modelName = callee.firstToken().value();
124147
}
125148

0 commit comments

Comments
 (0)