
Commit 8134b9c

Merge pull request #73 from cleophass/GCI100-python
GCI100 AI/PyTorch DisableGradientForModelEval #Python #DLG #Build
2 parents f44d097 + 9ae386e commit 8134b9c

File tree

8 files changed: +337 -0 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- [#73](https://github.com/green-code-initiative/creedengo-python/pull/73) Add rule GCI100 Disable Gradient For model eval, a rule specific to PyTorch and AI/ML
 - [#77](https://github.com/green-code-initiative/creedengo-python/pull/77) Add rule GCI104 AvoidCreatingTensorUsingNumpyOrNativePython, a rule specific to AI/ML code
 - [#70](https://github.com/green-code-initiative/creedengo-python/pull/70) Add rule GCI108 Prefer Append Left (a rule to prefer the use of `append` over `insert` for list, using deques)
 - [#78](https://github.com/green-code-initiative/creedengo-python/pull/78) Add rule GCI105 on String Concatenation. This rule may also apply to other rules
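
For context, a minimal sketch of the pattern GCI100 targets, condensed from the test resources added further down in this diff (illustration only, not part of the commit): calling a model that has been switched to eval() without disabling autograd, versus wrapping the inference call in torch.no_grad().

    import torch
    from torchvision import models

    model = models.resnet18(pretrained=True)
    model.eval()
    x = torch.randn(1, 3, 224, 224)

    output = model(x)          # flagged by GCI100: autograd still records the forward pass

    with torch.no_grad():      # compliant: gradient tracking disabled during inference
        output = model(x)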

src/it/java/org/greencodeinitiative/creedengo/python/integration/tests/GCIRulesIT.java

Lines changed: 16 additions & 0 deletions
@@ -320,6 +320,22 @@ void testGCI99(){
 
         checkIssuesForFile(filePath, ruleId, ruleMsg, startLines, endLines, SEVERITY, TYPE, EFFORT_50MIN);
     }
+
+    @Test
+    void testGCI100() {
+
+        String filePath = "src/disableGradientForModelEval.py";
+        String ruleId = "creedengo-python:GCI100";
+        String ruleMsg = "PyTorch : Disable gradient computation when evaluating a model to save memory and computation time";
+        int[] startLines = new int[]{
+                19, 29, 38
+        };
+        int[] endLines = new int[]{
+                19, 29, 38
+        };
+
+        checkIssuesForFile(filePath, ruleId, ruleMsg, startLines, endLines, SEVERITY, TYPE, EFFORT_10MIN);
+    }
 
     @Test
     void testGCI101(){
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+
+class SimpleModel(nn.Module):
+    def __init__(self):
+        super(SimpleModel, self).__init__()
+        self.linear = nn.Linear(10, 2)
+
+    def forward(self, x):
+        return self.linear(x)
+
+model = models.resnet18(pretrained=True)
+model.eval()
+
+input_tensor = torch.randn(1, 3, 224, 224, requires_grad=True)
+
+output = model(input_tensor) # Noncompliant {{PyTorch : Disable gradient computation when evaluating a model to save memory and computation time}}
+
+score = output[0].max()
+
+
+def non_compliant_without_no_grad():
+    model = SimpleModel()
+    model.eval()
+
+    inputs = torch.randn(1, 10)
+    outputs = model(inputs) # Noncompliant {{PyTorch : Disable gradient computation when evaluating a model to save memory and computation time}}
+
+    return outputs
+
+def non_compliant_with_different_model_name():
+    my_neural_net = SimpleModel()
+    my_neural_net.eval()
+
+    inputs = torch.randn(1, 10)
+    outputs = my_neural_net(inputs) # Noncompliant {{PyTorch : Disable gradient computation when evaluating a model to save memory and computation time}}
+
+    return outputs
+
+def compliant_with_no_grad():
+    model = SimpleModel()
+    model.eval()
+
+    inputs = torch.randn(1, 10)
+    with torch.no_grad():
+        outputs = model(inputs)
+
+    return outputs
+
+def compliant_without_eval():
+    model = SimpleModel()
+
+    inputs = torch.randn(1, 10)
+    outputs = model(inputs)
+
+    return outputs
+
src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ public class PythonRuleRepository implements RulesDefinition, PythonCustomRuleRe
             AvoidIterativeMatrixOperations.class,
             AvoidNonPinnedMemoryForDataloaders.class,
             AvoidConvBiasBeforeBatchNorm.class,
+            DisableGradientForModelEval.class,
             StringConcatenation.class,
             PreferAppendLeft.class,
             AvoidCreatingTensorUsingNumpyOrNativePython.class
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
+/*
+ * creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs
+ * Copyright © 2024 Green Code Initiative (https://green-code-initiative.org)
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.greencodeinitiative.creedengo.python.checks;
+
+
+import org.sonar.check.Rule;
+import org.sonar.plugins.python.api.PythonSubscriptionCheck;
+import org.sonar.plugins.python.api.SubscriptionContext;
+import org.sonar.plugins.python.api.tree.Expression;
+import org.sonar.plugins.python.api.tree.CallExpression;
+import org.sonar.plugins.python.api.tree.FunctionDef;
+import org.sonar.plugins.python.api.tree.QualifiedExpression;
+import org.sonar.plugins.python.api.tree.Tree;
+import org.sonar.plugins.python.api.tree.WithItem;
+import org.sonar.plugins.python.api.tree.WithStatement;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+@Rule(key = "GCI100")
+public class DisableGradientForModelEval extends PythonSubscriptionCheck {
+
+    private static final String DESCRIPTION = "PyTorch : Disable gradient computation when evaluating a model to save memory and computation time";
+
+    private final Map<Tree, Set<String>> evalModelsInContext = new HashMap<>();
+    private final Map<Tree, Set<Tree>> noGradScopesInContext = new HashMap<>();
+
+    @Override
+    public void initialize(Context context) {
+        context.registerSyntaxNodeConsumer(Tree.Kind.QUALIFIED_EXPR, this::checkEvalCall);
+        context.registerSyntaxNodeConsumer(Tree.Kind.WITH_STMT, this::checkWithNoGrad);
+        context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, this::checkModelCall);
+
+        context.registerSyntaxNodeConsumer(Tree.Kind.FUNCDEF, this::initializeContext);
+
+        evalModelsInContext.put(null, new HashSet<>());
+        noGradScopesInContext.put(null, new HashSet<>());
+    }
+
+    private void initializeContext(SubscriptionContext context) {
+        FunctionDef funcDef = (FunctionDef) context.syntaxNode();
+        evalModelsInContext.put(funcDef, new HashSet<>());
+        noGradScopesInContext.put(funcDef, new HashSet<>());
+    }
+
+    private Tree getEnclosingContext(Tree node) {
+        Tree current = node;
+        while (current != null) {
+            if (current.is(Tree.Kind.FUNCDEF)) {
+                return current;
+            }
+            current = current.parent();
+        }
+        return null;
+    }
+
+    private void checkEvalCall(SubscriptionContext context) {
+        QualifiedExpression expr = (QualifiedExpression) context.syntaxNode();
+
+
+        if (expr.name().name().equals("eval")) {
+            if (expr.qualifier() != null && expr.qualifier().firstToken() != null) {
+                String modelName = expr.qualifier().firstToken().value();
+
+
+                Tree enclosingContext = getEnclosingContext(expr);
+
+                evalModelsInContext.computeIfAbsent(enclosingContext, k -> new HashSet<>()).add(modelName);
+            }
+        }
+    }
+
+    private void checkWithNoGrad(SubscriptionContext context) {
+        WithStatement withStmt = (WithStatement) context.syntaxNode();
+
+
+        for (WithItem item : withStmt.withItems()) {
+            if (isNoGradCall(item.test())) {
+
+                Tree enclosingContext = getEnclosingContext(withStmt);
+
+
+                noGradScopesInContext.computeIfAbsent(enclosingContext, k -> new HashSet<>()).add(withStmt);
+                return;
+            }
+        }
+    }
+
+    private boolean isNoGradCall(Expression expr) {
+        if (expr.is(Tree.Kind.CALL_EXPR)) {
+            CallExpression callExpr = (CallExpression) expr;
+            if (callExpr.callee().is(Tree.Kind.QUALIFIED_EXPR)) {
+                QualifiedExpression qualExpr = (QualifiedExpression) callExpr.callee();
+                return qualExpr.qualifier() != null &&
+                        qualExpr.qualifier().firstToken() != null &&
+                        qualExpr.qualifier().firstToken().value().equals("torch") &&
+                        qualExpr.name().name().equals("no_grad");
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Checks if a model call is made in evaluation mode without the `torch.no_grad()` context.
+     * <p>
+     * This method identifies calls to models and verifies if they are in evaluation mode
+     * (tracked by `evalModelsInContext`). If a model is in evaluation mode and the call is not
+     * within a `torch.no_grad()` context, an issue is reported.
+     * </p>
+     *
+     * @param context The subscription context containing the syntax node for the model call.
+     *                This is used to extract the call expression and its enclosing context.
+     */
+    private void checkModelCall(SubscriptionContext context) {
+        CallExpression callExpr = (CallExpression) context.syntaxNode();
+
+        Tree enclosingContext = getEnclosingContext(callExpr);
+
+
+        Expression callee = callExpr.callee();
+        String modelName = null;
+
+        if (callee.is(Tree.Kind.QUALIFIED_EXPR)) {
+            QualifiedExpression qualExpr = (QualifiedExpression) callee;
+            if (qualExpr.qualifier() != null && qualExpr.qualifier().firstToken() != null) {
+                modelName = qualExpr.qualifier().firstToken().value();
+            }
+        } else {
+            modelName = callee.firstToken().value();
+        }
+
+        Set<String> modelsInEvalMode = evalModelsInContext.getOrDefault(enclosingContext, new HashSet<>());
+
+        if (modelName != null && modelsInEvalMode.contains(modelName)) {
+
+            if (!isInNoGradContext(callExpr, enclosingContext)) {
+                context.addIssue(callExpr, DESCRIPTION);
+            }
+        }
+    }
+
+    private boolean isInNoGradContext(Tree tree, Tree enclosingContext) {
+        Set<Tree> noGradScopes = noGradScopesInContext.getOrDefault(enclosingContext, new HashSet<>());
+
+        Tree current = tree;
+        while (current != null && current != enclosingContext) {
+            if (noGradScopes.contains(current)) {
+                return true;
+            }
+            current = current.parent();
+        }
+        return false;
+    }
+}

src/main/resources/org/greencodeinitiative/creedengo/python/creedengo_way_profile.json

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
     "GCI96",
     "GCI97",
     "GCI99",
+    "GCI100",
     "GCI101",
    "GCI102",
    "GCI103",
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+/*
+ * creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs
+ * Copyright © 2024 Green Code Initiative (https://green-code-initiative.org)
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.greencodeinitiative.creedengo.python.checks;
+
+import org.junit.Test;
+import org.sonar.python.checks.utils.PythonCheckVerifier;
+
+public class DisableGradientForModelEvalTest {
+
+    @Test
+    public void test() {
+        PythonCheckVerifier.verify("src/test/resources/checks/disableGradientForModelEval.py", new DisableGradientForModelEval());
+    }
+}
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+
+class SimpleModel(nn.Module):
+    def __init__(self):
+        super(SimpleModel, self).__init__()
+        self.linear = nn.Linear(10, 2)
+
+    def forward(self, x):
+        return self.linear(x)
+
+model = models.resnet18(pretrained=True)
+model.eval()
+
+input_tensor = torch.randn(1, 3, 224, 224, requires_grad=True)
+
+output = model(input_tensor) # Noncompliant {{PyTorch : Disable gradient computation when evaluating a model to save memory and computation time}}
+
+score = output[0].max()
+
+
+def non_compliant_without_no_grad():
+    model = SimpleModel()
+    model.eval()
+
+    inputs = torch.randn(1, 10)
+    outputs = model(inputs) # Noncompliant {{PyTorch : Disable gradient computation when evaluating a model to save memory and computation time}}
+
+    return outputs
+
+def non_compliant_with_different_model_name():
+    my_neural_net = SimpleModel()
+    my_neural_net.eval()
+
+    inputs = torch.randn(1, 10)
+    outputs = my_neural_net(inputs) # Noncompliant {{PyTorch : Disable gradient computation when evaluating a model to save memory and computation time}}
+
+    return outputs
+
+def compliant_with_no_grad():
+    model = SimpleModel()
+    model.eval()
+
+    inputs = torch.randn(1, 10)
+    with torch.no_grad():
+        outputs = model(inputs)
+
+    return outputs
+
+def compliant_without_eval():
+    model = SimpleModel()
+
+    inputs = torch.randn(1, 10)
+    outputs = model(inputs)
+
+    return outputs
+
