Skip to content

Commit 02a422e

Browse files
committed
correction after GCI104 merging
1 parent 6e23928 commit 02a422e

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

src/main/java/org/greencodeinitiative/creedengo/python/checks/AvoidCreatingTensorUsingNumpyOrNativePython.java

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@
3333
@Rule(key = "GCI104")
3434
public class AvoidCreatingTensorUsingNumpyOrNativePython extends PythonSubscriptionCheck {
3535

36-
private static final String dataArgumentName = "data";
37-
private static final int dataArgumentPosition = 0;
38-
private static final Map<String, String> torchOtherFunctionsMapping = Map.ofEntries(
36+
private static final String DATA_ARGUMENT_NAME = "data";
37+
private static final int DATA_ARGUMENT_POSITION = 0;
38+
private static final Map<String, String> TORCH_OTHER_FUNCTIONS_MAPPING = Map.ofEntries(
3939
entry("numpy.random.rand", "torch.rand"),
4040
entry("numpy.random.randint", "torch.randint"),
4141
entry("numpy.random.randn", "torch.randn"),
@@ -52,7 +52,7 @@ public class AvoidCreatingTensorUsingNumpyOrNativePython extends PythonSubscript
5252
entry("numpy.identity", "torch.eye"),
5353
entry("numpy.tile", "torch.tile")
5454
);
55-
private static final List<String> torchTensorConstructors = List.of(
55+
private static final List<String> TORCH_TENSOR_CONSTRUCTORS = List.of(
5656
"torch.tensor", "torch.FloatTensor",
5757
"torch.DoubleTensor", "torch.HalfTensor",
5858
"torch.BFloat16Tensor", "torch.ByteTensor",
@@ -68,18 +68,17 @@ public class AvoidCreatingTensorUsingNumpyOrNativePython extends PythonSubscript
6868

6969
@Override
7070
public void initialize(Context context) {
71-
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, ctx -> {
71+
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, ctx -> {
7272
CallExpression callExpression = (CallExpression) ctx.syntaxNode();
73-
if (torchTensorConstructors.contains(UtilsAST.getQualifiedName(callExpression))) {
74-
RegularArgument tensorCreatorArgument = UtilsAST.nthArgumentOrKeyword(dataArgumentPosition, dataArgumentName, callExpression.arguments());
75-
if (tensorCreatorArgument != null) {
76-
if (tensorCreatorArgument.expression().is(CALL_EXPR)) {
77-
String functionQualifiedName = UtilsAST.getQualifiedName((CallExpression) tensorCreatorArgument.expression());
78-
if (torchOtherFunctionsMapping.containsKey(functionQualifiedName)) {
79-
ctx.addIssue(callExpression, MESSAGE);
80-
}
73+
74+
if (TORCH_TENSOR_CONSTRUCTORS.contains(UtilsAST.getQualifiedName(callExpression))) {
75+
RegularArgument tensorCreatorArgument = UtilsAST.nthArgumentOrKeyword(DATA_ARGUMENT_POSITION, DATA_ARGUMENT_NAME, callExpression.arguments());
76+
if (tensorCreatorArgument != null && tensorCreatorArgument.expression().is(CALL_EXPR)) {
77+
String functionQualifiedName = UtilsAST.getQualifiedName((CallExpression) tensorCreatorArgument.expression());
78+
if (TORCH_OTHER_FUNCTIONS_MAPPING.containsKey(functionQualifiedName)) {
79+
ctx.addIssue(callExpression, MESSAGE);
80+
}
8181
}
82-
}
8382
}
8483
});
8584
}

0 commit comments

Comments
 (0)