3333@ Rule (key = "GCI104" )
3434public class AvoidCreatingTensorUsingNumpyOrNativePython extends PythonSubscriptionCheck {
3535
36- private static final String dataArgumentName = "data" ;
37- private static final int dataArgumentPosition = 0 ;
38- private static final Map <String , String > torchOtherFunctionsMapping = Map .ofEntries (
36+ private static final String DATA_ARGUMENT_NAME = "data" ;
37+ private static final int DATA_ARGUMENT_POSITION = 0 ;
38+ private static final Map <String , String > TORCH_OTHER_FUNCTIONS_MAPPING = Map .ofEntries (
3939 entry ("numpy.random.rand" , "torch.rand" ),
4040 entry ("numpy.random.randint" , "torch.randint" ),
4141 entry ("numpy.random.randn" , "torch.randn" ),
@@ -52,7 +52,7 @@ public class AvoidCreatingTensorUsingNumpyOrNativePython extends PythonSubscript
5252 entry ("numpy.identity" , "torch.eye" ),
5353 entry ("numpy.tile" , "torch.tile" )
5454 );
55- private static final List <String > torchTensorConstructors = List .of (
55+ private static final List <String > TORCH_TENSOR_CONSTRUCTORS = List .of (
5656 "torch.tensor" , "torch.FloatTensor" ,
5757 "torch.DoubleTensor" , "torch.HalfTensor" ,
5858 "torch.BFloat16Tensor" , "torch.ByteTensor" ,
@@ -68,18 +68,17 @@ public class AvoidCreatingTensorUsingNumpyOrNativePython extends PythonSubscript
6868
6969 @ Override
7070 public void initialize (Context context ) {
71- context .registerSyntaxNodeConsumer (Tree .Kind .CALL_EXPR , ctx -> {
71+ context .registerSyntaxNodeConsumer (Tree .Kind .CALL_EXPR , ctx -> {
7272 CallExpression callExpression = (CallExpression ) ctx .syntaxNode ();
73- if ( torchTensorConstructors . contains ( UtilsAST . getQualifiedName ( callExpression ))) {
74- RegularArgument tensorCreatorArgument = UtilsAST . nthArgumentOrKeyword ( dataArgumentPosition , dataArgumentName , callExpression . arguments ());
75- if ( tensorCreatorArgument != null ) {
76- if (tensorCreatorArgument .expression ().is (CALL_EXPR )) {
77- String functionQualifiedName = UtilsAST .getQualifiedName ((CallExpression ) tensorCreatorArgument .expression ());
78- if (torchOtherFunctionsMapping .containsKey (functionQualifiedName )) {
79- ctx .addIssue (callExpression , MESSAGE );
80- }
73+
74+ if ( TORCH_TENSOR_CONSTRUCTORS . contains ( UtilsAST . getQualifiedName ( callExpression ))) {
75+ RegularArgument tensorCreatorArgument = UtilsAST . nthArgumentOrKeyword ( DATA_ARGUMENT_POSITION , DATA_ARGUMENT_NAME , callExpression . arguments ());
76+ if (tensorCreatorArgument != null && tensorCreatorArgument .expression ().is (CALL_EXPR )) {
77+ String functionQualifiedName = UtilsAST .getQualifiedName ((CallExpression ) tensorCreatorArgument .expression ());
78+ if (TORCH_OTHER_FUNCTIONS_MAPPING .containsKey (functionQualifiedName )) {
79+ ctx .addIssue (callExpression , MESSAGE );
80+ }
8181 }
82- }
8382 }
8483 });
8584 }
0 commit comments