Skip to content

Commit 828e406

Browse files
GCI104 AI AvoidCreatingTensorUsingNumpyOrNativePython : Add IT test and remove utils class
Co-authored-by: DataLabGroupe-CreditAgricole <[email protected]>
1 parent 02ea926 commit 828e406

File tree

7 files changed

+64
-113
lines changed

7 files changed

+64
-113
lines changed

src/it/java/org/greencodeinitiative/creedengo/python/integration/tests/GCIRulesIT.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,22 @@ void testGCI103(){
366366

367367
checkIssuesForFile(filePath, ruleId, ruleMsg, startLines, endLines, SEVERITY, TYPE, EFFORT_1MIN);
368368
}
369+
370+
@Test
371+
void testGCI104() {
372+
373+
String filePath = "src/avoidCreatingTensorUsingNumpyOrNativePython.py";
374+
String ruleId = "creedengo-python:GCI104";
375+
String ruleMsg = "Directly create tensors as torch.Tensor instead of using numpy functions.";
376+
int[] startLines = new int[]{
377+
5, 15, 19, 24
378+
};
379+
int[] endLines = new int[]{
380+
5, 15, 19, 24
381+
};
382+
383+
checkIssuesForFile(filePath, ruleId, ruleMsg, startLines, endLines, SEVERITY, TYPE, EFFORT_10MIN);
384+
}
369385

370386
@Test
371387
void testGCI105() {
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import numpy as np
2+
import torch
3+
4+
def non_compliant_random_rand():
5+
tensor = torch.tensor(np.random.rand(1000, 1000)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
6+
7+
def compliant_random_rand():
8+
tensor = torch.rand([1000, 1000])
9+
10+
def compliant_zeros():
11+
tensor_ = torch.zeros(1, 2)
12+
print(tensor_)
13+
14+
def non_compliant_zeros():
15+
tensor_ = torch.IntTensor(np.zeros(1, 2)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
16+
print(tensor_)
17+
18+
def non_compliant_eye():
19+
tensor = torch.cuda.LongTensor(np.eye(5)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
20+
21+
def non_compliant_ones():
22+
import numpy
23+
from torch import FloatTensor
24+
tensor = FloatTensor(data=np.ones(shape=(1, 5))) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}

src/main/java/org/greencodeinitiative/creedengo/python/checks/AvoidCreatingTensorUsingNumpyOrNativePython.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
*/
1818
package org.greencodeinitiative.creedengo.python.checks;
1919

20-
import org.sonar.check.Priority;
20+
import org.greencodeinitiative.creedengo.python.utils.UtilsAST;
2121
import org.sonar.check.Rule;
2222
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
2323
import org.sonar.plugins.python.api.tree.Tree;
@@ -33,7 +33,6 @@
3333
@Rule(key = "GCI104")
3434
public class AvoidCreatingTensorUsingNumpyOrNativePython extends PythonSubscriptionCheck {
3535

36-
public static final String RULE_KEY = "P5";
3736
private static final String dataArgumentName = "data";
3837
private static final int dataArgumentPosition = 0;
3938
private static final Map<String, String> torchOtherFunctionsMapping = Map.ofEntries(
@@ -65,19 +64,19 @@ public class AvoidCreatingTensorUsingNumpyOrNativePython extends PythonSubscript
6564
"torch.cuda.CharTensor", "torch.cuda.ShortTensor",
6665
"torch.cuda.IntTensor", "torch.cuda.LongTensor",
6766
"torch.cuda.BoolTensor");
68-
protected static final String MESSAGE = "Directly create tensors as torch.Tensor. Use %s instead of %s.";
67+
protected static final String MESSAGE = "Directly create tensors as torch.Tensor instead of using numpy functions.";
6968

7069
@Override
7170
public void initialize(Context context) {
7271
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, ctx -> {
7372
CallExpression callExpression = (CallExpression) ctx.syntaxNode();
74-
if (torchTensorConstructors.contains(Utils.getQualifiedName(callExpression))) {
75-
RegularArgument tensorCreatorArgument = Utils.nthArgumentOrKeyword(dataArgumentPosition, dataArgumentName, callExpression.arguments());
73+
if (torchTensorConstructors.contains(UtilsAST.getQualifiedName(callExpression))) {
74+
RegularArgument tensorCreatorArgument = UtilsAST.nthArgumentOrKeyword(dataArgumentPosition, dataArgumentName, callExpression.arguments());
7675
if (tensorCreatorArgument != null) {
7776
if (tensorCreatorArgument.expression().is(CALL_EXPR)) {
78-
String functionQualifiedName = Utils.getQualifiedName((CallExpression) tensorCreatorArgument.expression());
77+
String functionQualifiedName = UtilsAST.getQualifiedName((CallExpression) tensorCreatorArgument.expression());
7978
if (torchOtherFunctionsMapping.containsKey(functionQualifiedName)) {
80-
ctx.addIssue(callExpression, String.format(MESSAGE, torchOtherFunctionsMapping.get(functionQualifiedName), functionQualifiedName));
79+
ctx.addIssue(callExpression, MESSAGE);
8180
}
8281
}
8382
}

src/main/java/org/greencodeinitiative/creedengo/python/checks/Utils.java

Lines changed: 0 additions & 102 deletions
This file was deleted.

src/main/java/org/greencodeinitiative/creedengo/python/utils/UtilsAST.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,19 @@ public static String getQualifiedName(CallExpression callExpression) {
5252
.orElse("");
5353
}
5454

55+
/**
56+
* Retrieves the variable name from the given SubscriptionContext.
57+
*
58+
* This method traverses the syntax tree of the provided context to locate
59+
* the nearest assignment statement. If an assignment statement is found,
60+
* it extracts the name of the variable on the left-hand side of the assignment.
61+
*
62+
* @param context The SubscriptionContext containing the syntax node to analyze.
63+
* It may be null or contain a null syntax node, in which case
64+
* the method returns null.
65+
* @return The name of the variable on the left-hand side of the assignment
66+
* statement, or null if no valid variable name can be determined.
67+
*/
5568
public static String getVariableName(SubscriptionContext context) {
5669

5770
if (context == null || context.syntaxNode() == null) {

src/main/resources/org/greencodeinitiative/creedengo/python/creedengo_way_profile.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"GCI101",
1717
"GCI102",
1818
"GCI103",
19+
"GCI104",
1920
"GCI105",
2021
"GCI106",
2122
"GCI107",

src/test/resources/checks/avoidCreatingTensorUsingNumpyOrNativePython.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33

44
def non_compliant_random_rand():
5-
tensor = torch.tensor(np.random.rand(1000, 1000)) # Noncompliant {{Directly create tensors as torch.Tensor. Use torch.rand instead of numpy.random.rand.}}
5+
tensor = torch.tensor(np.random.rand(1000, 1000)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
66

77
def compliant_random_rand():
88
tensor = torch.rand([1000, 1000])
@@ -12,13 +12,13 @@ def compliant_zeros():
1212
print(tensor_)
1313

1414
def non_compliant_zeros():
15-
tensor_ = torch.IntTensor(np.zeros(1, 2)) # Noncompliant {{Directly create tensors as torch.Tensor. Use torch.zeros instead of numpy.zeros.}}
15+
tensor_ = torch.IntTensor(np.zeros(1, 2)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
1616
print(tensor_)
1717

1818
def non_compliant_eye():
19-
tensor = torch.cuda.LongTensor(np.eye(5)) # Noncompliant {{Directly create tensors as torch.Tensor. Use torch.eye instead of numpy.eye.}}
19+
tensor = torch.cuda.LongTensor(np.eye(5)) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}
2020

2121
def non_compliant_ones():
2222
import numpy
2323
from torch import FloatTensor
24-
tensor = FloatTensor(data=np.ones(shape=(1, 5))) # Noncompliant {{Directly create tensors as torch.Tensor. Use torch.ones instead of numpy.ones.}}
24+
tensor = FloatTensor(data=np.ones(shape=(1, 5))) # Noncompliant {{Directly create tensors as torch.Tensor instead of using numpy functions.}}

0 commit comments

Comments
 (0)