Skip to content

Commit 5947a5d

Browse files
authored
SONARPY-1915 implement S6929 for PyTorch (#1962)
* SONARPY-1915 implement S6929 for PyTorch * Implement QuickFix for TensorFlow reduce functions * prevent raising an issue if an unpack expression is present * SONARPY-1915 add test case for if fqn is null * SONARPY-1915 use Expressions.containsSpreadOperator(...) * SONARPY-1915 extract constant in test
1 parent 7d62937 commit 5947a5d

File tree

7 files changed

+288
-96
lines changed

7 files changed

+288
-96
lines changed

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ public static Iterable<Class> getChecks() {
376376
TfFunctionRecursivityCheck.class,
377377
TfInputShapeOnModelSubclassCheck.class,
378378
TfGatherDeprecatedValidateIndicesCheck.class,
379-
TfSpecifyReductionAxisCheck.class,
379+
TfPyTorchSpecifyReductionAxisCheck.class,
380380
ReferencedBeforeAssignmentCheck.class,
381381
RegexComplexityCheck.class,
382382
RegexLookaheadCheck.class,
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.Arrays;
23+
import java.util.HashSet;
24+
import java.util.Map;
25+
import java.util.Optional;
26+
import java.util.Set;
27+
import org.sonar.check.Rule;
28+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
29+
import org.sonar.plugins.python.api.SubscriptionContext;
30+
import org.sonar.plugins.python.api.quickfix.PythonQuickFix;
31+
import org.sonar.plugins.python.api.symbols.Symbol;
32+
import org.sonar.plugins.python.api.tree.CallExpression;
33+
import org.sonar.plugins.python.api.tree.Tree;
34+
import org.sonar.python.checks.utils.Expressions;
35+
import org.sonar.python.quickfix.TextEditUtils;
36+
import org.sonar.python.tree.TreeUtils;
37+
38+
@Rule(key = "S6929")
39+
public class TfPyTorchSpecifyReductionAxisCheck extends PythonSubscriptionCheck {
40+
41+
private static final Set<String> TF_REDUCTION_FUNCTIONS = new HashSet<>(Arrays.asList("reduce_all", "reduce_mean", "reduce_any",
42+
"reduce_euclidean_norm", "reduce_logsumexp",
43+
"reduce_max", "reduce_min", "reduce_prod", "reduce_std", "reduce_sum", "reduce_variance"));
44+
private static final Set<String> TF_REDUCTION_FUNCTIONS_FQN = new HashSet<>();
45+
private static final String TF_MESSAGE = "Provide a value for the axis argument.";
46+
public static final String AXIS_PARAMETER = "axis";
47+
public static final int AXIS_PARAMETER_POSITION = 1;
48+
49+
static {
50+
for (String reductionFunction : TF_REDUCTION_FUNCTIONS) {
51+
TF_REDUCTION_FUNCTIONS_FQN.add("tensorflow.math." + reductionFunction);
52+
TF_REDUCTION_FUNCTIONS_FQN.add("tensorflow.tf." + reductionFunction);
53+
}
54+
}
55+
56+
private static final String PY_TORCH_MESSAGE = "Provide a value for the dim argument.";
57+
public static final String DIM_PARAMETER = "dim";
58+
public static final int NO_POSITIONAL_ARG = -1;
59+
60+
/**
61+
* Contains a list of reduction functions with a {@code dim} parameter and the position of the dim argument in that function.
62+
*/
63+
private static final Map<String, Integer> PY_TORCH_REDUCTION_FUNCTIONS_DIM_POS = Map.ofEntries(
64+
Map.entry("torch.argmin", 1),
65+
Map.entry("torch.aminmax", NO_POSITIONAL_ARG),
66+
Map.entry("torch.nanmean", 1),
67+
Map.entry("torch.mode", 1),
68+
Map.entry("torch.norm", 2),
69+
Map.entry("torch.quantile", 2),
70+
Map.entry("torch.nanquantile", 2),
71+
Map.entry("torch.std", 1),
72+
Map.entry("torch.std_mean", 1),
73+
Map.entry("torch.unique", 4),
74+
Map.entry("torch.unique_consecutive", 3),
75+
Map.entry("torch.var", 1),
76+
Map.entry("torch.var_mean", 1),
77+
Map.entry("torch.count_nonzero", 1)
78+
);
79+
80+
@Override
81+
public void initialize(Context context) {
82+
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, TfPyTorchSpecifyReductionAxisCheck::checkCallExpr);
83+
}
84+
85+
private static void checkCallExpr(SubscriptionContext context) {
86+
CallExpression callExpression = (CallExpression) context.syntaxNode();
87+
Symbol symbol = callExpression.calleeSymbol();
88+
if (symbol != null && !Expressions.containsSpreadOperator(callExpression.arguments())) {
89+
if (isTfReductionMissingAxisArg(symbol, callExpression)) {
90+
PreciseIssue issue = context.addIssue(callExpression.callee(), TF_MESSAGE);
91+
createTfQuickFix(callExpression).ifPresent(issue::addQuickFix);
92+
}
93+
94+
if (isPyTorchReductionMissingDimArg(symbol, callExpression)) {
95+
context.addIssue(callExpression.callee(), PY_TORCH_MESSAGE);
96+
}
97+
}
98+
}
99+
100+
private static Optional<PythonQuickFix> createTfQuickFix(CallExpression callExpression) {
101+
if(callExpression.arguments().isEmpty()) {
102+
return Optional.empty();
103+
}
104+
return Optional.of(PythonQuickFix.newQuickFix("Add axis parameter", TextEditUtils.insertBefore(callExpression.rightPar(), ", axis=None")));
105+
}
106+
107+
private static boolean isTfReductionMissingAxisArg(Symbol symbol, CallExpression callExpression) {
108+
return TF_REDUCTION_FUNCTIONS_FQN.contains(symbol.fullyQualifiedName())
109+
&& TreeUtils.nthArgumentOrKeyword(AXIS_PARAMETER_POSITION, AXIS_PARAMETER, callExpression.arguments()) == null;
110+
}
111+
112+
private static boolean isPyTorchReductionMissingDimArg(Symbol symbol, CallExpression callExpression) {
113+
String fqn = symbol.fullyQualifiedName();
114+
return fqn != null && PY_TORCH_REDUCTION_FUNCTIONS_DIM_POS.containsKey(fqn)
115+
&& TreeUtils.nthArgumentOrKeyword(PY_TORCH_REDUCTION_FUNCTIONS_DIM_POS.get(fqn), DIM_PARAMETER, callExpression.arguments()) == null;
116+
}
117+
}

python-checks/src/main/java/org/sonar/python/checks/TfSpecifyReductionAxisCheck.java

Lines changed: 0 additions & 65 deletions
This file was deleted.
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.sonar.python.checks.quickfix.PythonQuickFixVerifier;
24+
import org.sonar.python.checks.utils.PythonCheckVerifier;
25+
26+
class TfPyTorchSpecifyReductionAxisCheckTest {
27+
private static final TfPyTorchSpecifyReductionAxisCheck CHECK_OBJECT = new TfPyTorchSpecifyReductionAxisCheck();
28+
29+
@Test
30+
void testTensorFlow() {
31+
PythonCheckVerifier.verify("src/test/resources/checks/tfSpecifyReductionAxis.py", CHECK_OBJECT);
32+
}
33+
34+
@Test
35+
void testTensorFlowQuickFix() {
36+
PythonQuickFixVerifier.verify(CHECK_OBJECT,
37+
"""
38+
from tensorflow import math
39+
math.reduce_all(input)
40+
""",
41+
"""
42+
from tensorflow import math
43+
math.reduce_all(input, axis=None)
44+
""");
45+
46+
PythonQuickFixVerifier.verify(CHECK_OBJECT,
47+
"""
48+
from tensorflow import math
49+
math.reduce_all(input, keepdims=True)
50+
""",
51+
"""
52+
from tensorflow import math
53+
math.reduce_all(input, keepdims=True, axis=None)
54+
""");
55+
56+
PythonQuickFixVerifier.verifyNoQuickFixes(CHECK_OBJECT,
57+
"""
58+
from tensorflow import math
59+
math.reduce_all()
60+
""");
61+
}
62+
63+
@Test
64+
void testPyTorch() {
65+
PythonCheckVerifier.verify("src/test/resources/checks/pyTorchSpecifyReductionDim.py", CHECK_OBJECT);
66+
}
67+
}

python-checks/src/test/java/org/sonar/python/checks/TfSpecifyReductionAxisCheckTest.java

Lines changed: 0 additions & 30 deletions
This file was deleted.
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import torch
2+
3+
def argmin(input, dict_args):
4+
torch.argmin(input) # Noncompliant {{Provide a value for the dim argument.}}
5+
#^^^^^^^^^^^^
6+
7+
torch.argmin(**dict_args)
8+
torch.argmin(input, 2)
9+
torch.argmin(input, dim=2)
10+
11+
12+
def aminmax(input):
13+
torch.aminmax(input) # Noncompliant
14+
torch.aminmax(input, 2) #Noncompliant
15+
16+
torch.aminmax(input, dim=2)
17+
18+
def nanmean(input):
19+
torch.nanmean(input) # Noncompliant
20+
21+
torch.nanmean(input, 2)
22+
torch.nanmean(input, dim=2)
23+
24+
def mode(input):
25+
torch.mode(input) # Noncompliant
26+
27+
torch.mode(input, 2)
28+
torch.mode(input, dim=2)
29+
30+
def norm(input, p):
31+
torch.norm(input) # Noncompliant
32+
torch.norm(input, 2) #Noncompliant
33+
34+
torch.norm(input, p, 2)
35+
torch.norm(input, dim=2)
36+
37+
def quantile(input, q):
38+
torch.quantile(input) # Noncompliant
39+
torch.quantile(input, 2) #Noncompliant
40+
41+
torch.quantile(input, q, 2)
42+
torch.quantile(input, dim=2)
43+
44+
def nanquantile(input, q):
45+
torch.nanquantile(input) # Noncompliant
46+
torch.nanquantile(input, 2) #Noncompliant
47+
48+
torch.nanquantile(input, q, 2)
49+
torch.nanquantile(input, dim=2)
50+
51+
def std(input):
52+
torch.std(input) # Noncompliant
53+
54+
torch.std(input, 2)
55+
torch.std(input, dim=2)
56+
57+
def std_mean(input):
58+
torch.std_mean(input) # Noncompliant
59+
60+
torch.std_mean(input, 2)
61+
torch.std_mean(input, dim=2)
62+
63+
def unique(input, sorted, return_inverse, return_counts):
64+
torch.unique(input) # Noncompliant
65+
torch.unique(input, sorted, return_inverse, 2) #Noncompliant
66+
67+
torch.unique(input, sorted, return_inverse, return_counts, 2)
68+
torch.unique(input, dim=2)
69+
70+
def unique_consecutive(input, return_inverse, return_counts):
71+
torch.unique_consecutive(input) # Noncompliant
72+
torch.unique_consecutive(input, return_inverse, return_counts) #Noncompliant
73+
74+
torch.unique_consecutive(input, return_inverse, return_counts, 2)
75+
torch.unique_consecutive(input, dim=2)
76+
77+
def var(input):
78+
torch.var(input) # Noncompliant
79+
80+
torch.var(input, 2)
81+
torch.var(input, dim=2)
82+
83+
def var_mean(input):
84+
torch.var_mean(input) # Noncompliant
85+
86+
torch.var_mean(input, 2)
87+
torch.var_mean(input, dim=2)
88+
89+
def count_nonzero(input):
90+
torch.count_nonzero(input) # Noncompliant
91+
92+
torch.count_nonzero(input, 2)
93+
torch.count_nonzero(input, dim=2)
94+
95+
def multiple_imports():
96+
from torch import var_mean
97+
from other_module import var_mean
98+
99+
var_mean()

0 commit comments

Comments
 (0)