Skip to content

Commit 6e23928

Browse files
authored
Merge pull request #77 from cleophass/GCI104-python
GCI104 AI AvoidCreatingTensorUsingNumpyOrNativePython #Python #DLG #Build
2 parents 77f1ff0 + 3d1c3d2 commit 6e23928

File tree

9 files changed

+198
-1
lines changed

9 files changed

+198
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111

12+
- [#77](https://github.com/green-code-initiative/creedengo-python/pull/77) Add rule GCI104 AvoidCreatingTensorUsingNumpyOrNativePython, a rule specific to AI/ML code
1213
- [#70](https://github.com/green-code-initiative/creedengo-python/pull/70) Add rule GCI108 Prefer Append Left (a rule to prefer the use of `append` over `insert` for list, using deques)
1314
- [#78](https://github.com/green-code-initiative/creedengo-python/pull/78) Add rule GCI105 on String Concatenation. This rule may also apply to other rules
1415
- [#74](https://github.com/green-code-initiative/creedengo-python/pull/74) Add rule GCI101 Avoid Conv Bias Before Batch Normalization, a rule specific to Deeplearning

src/it/java/org/greencodeinitiative/creedengo/python/integration/tests/GCIRulesIT.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,22 @@ void testGCI103(){
366366

367367
checkIssuesForFile(filePath, ruleId, ruleMsg, startLines, endLines, SEVERITY, TYPE, EFFORT_1MIN);
368368
}
369+
370+
@Test
371+
void testGCI104() {
372+
373+
String filePath = "src/avoidCreatingTensorUsingNumpyOrNativePython.py";
374+
String ruleId = "creedengo-python:GCI104";
375+
String ruleMsg = "Directly create tensors as torch.Tensor instead of using numpy functions.";
376+
int[] startLines = new int[]{
377+
5, 15, 19, 24
378+
};
379+
int[] endLines = new int[]{
380+
5, 15, 19, 24
381+
};
382+
383+
checkIssuesForFile(filePath, ruleId, ruleMsg, startLines, endLines, SEVERITY, TYPE, EFFORT_10MIN);
384+
}
369385

370386
@Test
371387
void testGCI105() {
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import numpy as np
2+
import torch
3+
4+
def non_compliant_random_rand():
5+
tensor = torch.tensor(np.random.rand(1000, 1000)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
6+
7+
def compliant_random_rand():
8+
tensor = torch.rand([1000, 1000])
9+
10+
def compliant_zeros():
11+
tensor_ = torch.zeros(1, 2)
12+
print(tensor_)
13+
14+
def non_compliant_zeros():
15+
tensor_ = torch.IntTensor(np.zeros(1, 2)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
16+
print(tensor_)
17+
18+
def non_compliant_eye():
19+
tensor = torch.cuda.LongTensor(np.eye(5)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
20+
21+
def non_compliant_ones():
22+
import numpy
23+
from torch import FloatTensor
24+
tensor = FloatTensor(data=np.ones(shape=(1, 5))) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}

src/main/java/org/greencodeinitiative/creedengo/python/PythonRuleRepository.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ public class PythonRuleRepository implements RulesDefinition, PythonCustomRuleRe
5050
AvoidNonPinnedMemoryForDataloaders.class,
5151
AvoidConvBiasBeforeBatchNorm.class,
5252
StringConcatenation.class,
53-
PreferAppendLeft.class
53+
PreferAppendLeft.class,
54+
AvoidCreatingTensorUsingNumpyOrNativePython.class
5455
);
5556

5657
public static final String LANGUAGE = "py";
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs
3+
* Copyright © 2024 Green Code Initiative (https://green-code-initiative.org)
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU General Public License as published by
7+
* the Free Software Foundation, either version 3 of the License, or
8+
* (at your option) any later version.
9+
*
10+
* This program is distributed in the hope that it will be useful,
11+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
* GNU General Public License for more details.
14+
*
15+
* You should have received a copy of the GNU General Public License
16+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
17+
*/
18+
package org.greencodeinitiative.creedengo.python.checks;
19+
20+
import org.greencodeinitiative.creedengo.python.utils.UtilsAST;
21+
import org.sonar.check.Rule;
22+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
23+
import org.sonar.plugins.python.api.tree.Tree;
24+
import org.sonar.plugins.python.api.tree.CallExpression;
25+
import org.sonar.plugins.python.api.tree.RegularArgument;
26+
27+
import java.util.List;
28+
import java.util.Map;
29+
30+
import static java.util.Map.entry;
31+
import static org.sonar.plugins.python.api.tree.Tree.Kind.CALL_EXPR;
32+
33+
@Rule(key = "GCI104")
34+
public class AvoidCreatingTensorUsingNumpyOrNativePython extends PythonSubscriptionCheck {
35+
36+
private static final String dataArgumentName = "data";
37+
private static final int dataArgumentPosition = 0;
38+
private static final Map<String, String> torchOtherFunctionsMapping = Map.ofEntries(
39+
entry("numpy.random.rand", "torch.rand"),
40+
entry("numpy.random.randint", "torch.randint"),
41+
entry("numpy.random.randn", "torch.randn"),
42+
entry("numpy.zeros", "torch.zeros"),
43+
entry("numpy.zeros_like", "torch.zeros_like"),
44+
entry("numpy.ones", "torch.ones"),
45+
entry("numpy.ones_like", "torch.ones_like"),
46+
entry("numpy.full", "torch.full"),
47+
entry("numpy.full_like", "torch.full_like"),
48+
entry("numpy.eye", "torch.eye"),
49+
entry("numpy.arange", "torch.arange"),
50+
entry("numpy.linspace", "torch.linspace"),
51+
entry("numpy.logspace", "torch.logspace"),
52+
entry("numpy.identity", "torch.eye"),
53+
entry("numpy.tile", "torch.tile")
54+
);
55+
private static final List<String> torchTensorConstructors = List.of(
56+
"torch.tensor", "torch.FloatTensor",
57+
"torch.DoubleTensor", "torch.HalfTensor",
58+
"torch.BFloat16Tensor", "torch.ByteTensor",
59+
"torch.CharTensor", "torch.ShortTensor",
60+
"torch.IntTensor", "torch.LongTensor",
61+
"torch.BoolTensor", "torch.cuda.FloatTensor",
62+
"torch.cuda.DoubleTensor", "torch.cuda.HalfTensor",
63+
"torch.cuda.BFloat16Tensor", "torch.cuda.ByteTensor",
64+
"torch.cuda.CharTensor", "torch.cuda.ShortTensor",
65+
"torch.cuda.IntTensor", "torch.cuda.LongTensor",
66+
"torch.cuda.BoolTensor");
67+
protected static final String MESSAGE = "Directly create tensors as torch.Tensor instead of using numpy functions.";
68+
69+
@Override
70+
public void initialize(Context context) {
71+
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, ctx -> {
72+
CallExpression callExpression = (CallExpression) ctx.syntaxNode();
73+
if (torchTensorConstructors.contains(UtilsAST.getQualifiedName(callExpression))) {
74+
RegularArgument tensorCreatorArgument = UtilsAST.nthArgumentOrKeyword(dataArgumentPosition, dataArgumentName, callExpression.arguments());
75+
if (tensorCreatorArgument != null) {
76+
if (tensorCreatorArgument.expression().is(CALL_EXPR)) {
77+
String functionQualifiedName = UtilsAST.getQualifiedName((CallExpression) tensorCreatorArgument.expression());
78+
if (torchOtherFunctionsMapping.containsKey(functionQualifiedName)) {
79+
ctx.addIssue(callExpression, MESSAGE);
80+
}
81+
}
82+
}
83+
}
84+
});
85+
}
86+
}

src/main/java/org/greencodeinitiative/creedengo/python/utils/UtilsAST.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,19 @@ public static String getQualifiedName(CallExpression callExpression) {
5252
.orElse("");
5353
}
5454

55+
/**
56+
* Retrieves the variable name from the given SubscriptionContext.
57+
*
58+
* This method traverses the syntax tree of the provided context to locate
59+
* the nearest assignment statement. If an assignment statement is found,
60+
* it extracts the name of the variable on the left-hand side of the assignment.
61+
*
62+
* @param context The SubscriptionContext containing the syntax node to analyze.
63+
* It may be null or contain a null syntax node, in which case
64+
* the method returns null.
65+
* @return The name of the variable on the left-hand side of the assignment
66+
* statement, or null if no valid variable name can be determined.
67+
*/
5568
public static String getVariableName(SubscriptionContext context) {
5669

5770
if (context == null || context.syntaxNode() == null) {

src/main/resources/org/greencodeinitiative/creedengo/python/creedengo_way_profile.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"GCI101",
1717
"GCI102",
1818
"GCI103",
19+
"GCI104",
1920
"GCI105",
2021
"GCI106",
2122
"GCI107",
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* creedengo - Python language - Provides rules to reduce the environmental footprint of your Python programs
3+
* Copyright © 2024 Green Code Initiative (https://green-code-initiative.org)
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU General Public License as published by
7+
* the Free Software Foundation, either version 3 of the License, or
8+
* (at your option) any later version.
9+
*
10+
* This program is distributed in the hope that it will be useful,
11+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
* GNU General Public License for more details.
14+
*
15+
* You should have received a copy of the GNU General Public License
16+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
17+
*/
18+
package org.greencodeinitiative.creedengo.python.checks;
19+
20+
import org.junit.Test;
21+
import org.sonar.python.checks.utils.PythonCheckVerifier;
22+
23+
24+
25+
public class AvoidCreatingTensorUsingNumpyOrNativePythonTest {
26+
@Test
27+
public void test() {
28+
PythonCheckVerifier.verify("src/test/resources/checks/avoidCreatingTensorUsingNumpyOrNativePython.py", new AvoidCreatingTensorUsingNumpyOrNativePython());
29+
}
30+
31+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import numpy as np
2+
import torch
3+
4+
def non_compliant_random_rand():
5+
tensor = torch.tensor(np.random.rand(1000, 1000)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
6+
7+
def compliant_random_rand():
8+
tensor = torch.rand([1000, 1000])
9+
10+
def compliant_zeros():
11+
tensor_ = torch.zeros(1, 2)
12+
print(tensor_)
13+
14+
def non_compliant_zeros():
15+
tensor_ = torch.IntTensor(np.zeros(1, 2)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
16+
print(tensor_)
17+
18+
def non_compliant_eye():
19+
tensor = torch.cuda.LongTensor(np.eye(5)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
20+
21+
def non_compliant_ones():
22+
import numpy
23+
from torch import FloatTensor
24+
tensor = FloatTensor(data=np.ones(shape=(1, 5))) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}

0 commit comments

Comments
 (0)