Skip to content

Commit d79830f

Browse files
authored
SONARPY-1909: Einops pattern should be valid (#1960)
1 parent 41749af commit d79830f

File tree

7 files changed

+313
-0
lines changed

7 files changed

+313
-0
lines changed

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ public static Iterable<Class> getChecks() {
179179
DuplicatedMethodImplementationCheck.class,
180180
DuplicatesInCharacterClassCheck.class,
181181
DynamicCodeExecutionCheck.class,
182+
EinopsSyntaxCheck.class,
182183
ElseAfterLoopsWithoutBreakCheck.class,
183184
EmailSendingCheck.class,
184185
EmptyAlternativeCheck.class,
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.LinkedHashSet;
23+
import java.util.List;
24+
import java.util.Objects;
25+
import java.util.Optional;
26+
import java.util.Set;
27+
import java.util.regex.Matcher;
28+
import java.util.regex.Pattern;
29+
import java.util.stream.Collectors;
30+
import org.sonar.check.Rule;
31+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
32+
import org.sonar.plugins.python.api.SubscriptionContext;
33+
import org.sonar.plugins.python.api.symbols.Symbol;
34+
import org.sonar.plugins.python.api.tree.Argument;
35+
import org.sonar.plugins.python.api.tree.CallExpression;
36+
import org.sonar.plugins.python.api.tree.RegularArgument;
37+
import org.sonar.plugins.python.api.tree.StringLiteral;
38+
import org.sonar.plugins.python.api.tree.Tree;
39+
import org.sonar.python.tree.TreeUtils;
40+
41+
@Rule(key = "S6984")
42+
public class EinopsSyntaxCheck extends PythonSubscriptionCheck {
43+
44+
private static final String MESSAGE_TEMPLATE = "Fix the syntax of this einops operation: %s.";
45+
private static final String NESTED_PARENTHESIS_MESSAGE = "nested parenthesis are not allowed";
46+
private static final String LHS_ELLIPSIS_MESSAGE = "Ellipsis inside parenthesis on the left side is not allowed";
47+
private static final String UNBALANCED_PARENTHESIS_MESSAGE = "parenthesis are unbalanced";
48+
private static final Set<String> FQN_TO_CHECK = Set.of("einops.repeat", "einops.reduce", "einops.rearrange");
49+
private static final Pattern ellipsisPattern = Pattern.compile("\\((.*)\\)");
50+
51+
@Override
52+
public void initialize(Context context) {
53+
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, EinopsSyntaxCheck::checkEinopsSyntax);
54+
}
55+
56+
private static void checkEinopsSyntax(SubscriptionContext ctx) {
57+
CallExpression callExpression = (CallExpression) ctx.syntaxNode();
58+
Symbol calleeSymbol = callExpression.calleeSymbol();
59+
if (calleeSymbol != null && calleeSymbol.fullyQualifiedName() != null && FQN_TO_CHECK.contains(calleeSymbol.fullyQualifiedName())) {
60+
extractPatternFromCallExpr(callExpression).ifPresent(stringLiteral -> {
61+
var maybePattern = toEinopsPattern(stringLiteral);
62+
if (maybePattern.isPresent()) {
63+
var pattern = maybePattern.get();
64+
checkForEllipsisInParenthesis(ctx, pattern);
65+
checkForUnbalancedParenthesis(ctx, pattern);
66+
checkForUnusedParameter(ctx, callExpression.arguments(), pattern);
67+
} else {
68+
ctx.addIssue(callExpression.callee(), "Provide a valid einops pattern.");
69+
}
70+
});
71+
}
72+
}
73+
74+
private static void checkForUnusedParameter(SubscriptionContext ctx, List<Argument> arguments, EinopsPattern pattern) {
75+
List<String> argsToCheck = arguments.stream()
76+
.map(TreeUtils.toInstanceOfMapper(RegularArgument.class))
77+
.filter(Objects::nonNull)
78+
.filter(arg -> arg.expression().is(Tree.Kind.NUMERIC_LITERAL))
79+
.filter(arg -> arg.keywordArgument() != null)
80+
.map(arg -> arg.keywordArgument().name())
81+
.filter(argName -> !pattern.lhs.identifiers.contains(argName))
82+
.filter(argName -> !pattern.rhs.identifiers.contains(argName))
83+
.toList();
84+
85+
if (!argsToCheck.isEmpty()) {
86+
var isPlural = argsToCheck.size() > 1;
87+
var missingParameters = argsToCheck.stream().collect(Collectors.joining(", "));
88+
var missingParametersMessage = String.format("the parameter%s %s do%s not appear in the pattern", isPlural ? "s" : "", missingParameters, isPlural ? "" : "es");
89+
ctx.addIssue(pattern.originalPattern(), String.format(MESSAGE_TEMPLATE, missingParametersMessage));
90+
}
91+
}
92+
93+
private static void checkForUnbalancedParenthesis(SubscriptionContext ctx, EinopsPattern pattern) {
94+
pattern.lhs.state.errorMessage.or(() -> pattern.rhs.state.errorMessage)
95+
.ifPresent(message -> ctx.addIssue(pattern.originalPattern(), String.format(MESSAGE_TEMPLATE, message)));
96+
}
97+
98+
private static void checkForEllipsisInParenthesis(SubscriptionContext ctx, EinopsPattern pattern) {
99+
Matcher m = ellipsisPattern.matcher(pattern.lhs.originalPattern);
100+
if (m.find() && (m.group().contains("...") || m.group().contains("…"))) {
101+
ctx.addIssue(pattern.originalPattern(), String.format(MESSAGE_TEMPLATE, LHS_ELLIPSIS_MESSAGE));
102+
}
103+
}
104+
105+
private record EinopsPattern(StringLiteral originalPattern, EinopsSide lhs, EinopsSide rhs) {
106+
}
107+
108+
private record EinopsSide(String originalPattern, Set<String> identifiers, ParenthesisState state) {
109+
}
110+
111+
private record ParenthesisState(boolean hasOpenParenthesis, Optional<String> errorMessage) {
112+
}
113+
114+
private static Optional<StringLiteral> extractPatternFromCallExpr(CallExpression callExpression) {
115+
return Optional.ofNullable(TreeUtils.nthArgumentOrKeyword(1, "pattern", callExpression.arguments()))
116+
.map(RegularArgument::expression)
117+
.flatMap(TreeUtils.toOptionalInstanceOfMapper(StringLiteral.class));
118+
}
119+
120+
private static Optional<EinopsPattern> toEinopsPattern(StringLiteral pattern) {
121+
String[] split = pattern.trimmedQuotesValue().split("->");
122+
if (split.length == 2) {
123+
var lhsStr = split[0].trim();
124+
var rhsStr = split[1].trim();
125+
if (!lhsStr.isEmpty() && !rhsStr.isEmpty()) {
126+
var lhs = parseEinopsPattern(lhsStr);
127+
var rhs = parseEinopsPattern(rhsStr);
128+
return Optional.of(new EinopsPattern(pattern, lhs, rhs));
129+
}
130+
}
131+
return Optional.empty();
132+
}
133+
134+
private static EinopsSide parseEinopsPattern(String pattern) {
135+
Set<String> identifiers = new LinkedHashSet<>();
136+
var currentIdentifier = new StringBuilder();
137+
ParenthesisState state = new ParenthesisState(false, Optional.empty());
138+
for (int i = 0; i < pattern.length(); i++) {
139+
char c = pattern.charAt(i);
140+
if (c == ' ' || c == '(' || c == ')') {
141+
if (!currentIdentifier.isEmpty()) {
142+
identifiers.add(currentIdentifier.toString());
143+
currentIdentifier.setLength(0);
144+
}
145+
state = checkParenthesisBalance(c, state);
146+
} else if (Character.isLetterOrDigit(c) || c == '_' || c == '…') {
147+
currentIdentifier.append(c);
148+
}
149+
}
150+
if (!currentIdentifier.isEmpty()) {
151+
identifiers.add(currentIdentifier.toString());
152+
}
153+
if (state.hasOpenParenthesis && state.errorMessage.isEmpty()) {
154+
state = new ParenthesisState(true, Optional.of(UNBALANCED_PARENTHESIS_MESSAGE));
155+
}
156+
return new EinopsSide(pattern, identifiers, state);
157+
}
158+
159+
private static ParenthesisState checkParenthesisBalance(char c, ParenthesisState state) {
160+
Optional<String> errorMessage = state.errorMessage;
161+
if (' ' == c) {
162+
return state;
163+
}
164+
if ('(' == c && state.hasOpenParenthesis) {
165+
errorMessage = Optional.of(NESTED_PARENTHESIS_MESSAGE);
166+
}
167+
if (')' == c && !state.hasOpenParenthesis && errorMessage.isEmpty()) {
168+
errorMessage = Optional.of(UNBALANCED_PARENTHESIS_MESSAGE);
169+
}
170+
return new ParenthesisState('(' == c, errorMessage);
171+
}
172+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<p>This rule raises an issue when an incorrect pattern is provided to an <code>einops</code> operation.</p>
2+
<h2>Why is this an issue?</h2>
3+
<p>The <code>einops</code> library provides a powerful and flexible way to manipulate tensors using the Einstein summation convention. The
4+
<code>einops</code> library uses a different convention than the <a href="https://rockt.github.io/2018/04/30/einsum">traditional</a> one. In particular, the
5+
axis names can be more than one letter long and are separated by spaces.</p>
6+
<h2>How to fix it</h2>
7+
<p>Correct the syntax of the <code>einops</code> operation by balancing the parentheses and following the convention.</p>
8+
<h3>Code examples</h3>
9+
<h4>Noncompliant code example</h4>
10+
<pre data-diff-id="1" data-diff-type="noncompliant">
11+
from einops import rearrange
12+
import torch
13+
14+
x = torch.randn(2, 3, 4, 5)
15+
x2 = rearrange(x, 'b c h w -&gt; b (c h w') # Noncompliant: the parentheses are not balanced
16+
</pre>
17+
<h4>Compliant solution</h4>
18+
<pre data-diff-id="1" data-diff-type="compliant">
19+
from einops import rearrange
20+
import torch
21+
22+
x = torch.randn(2, 3, 4, 5)
23+
x2 = rearrange(x, 'b c h w -&gt; b (c h w)')
24+
</pre>
25+
<h2>Resources</h2>
26+
<h3>Documentation</h3>
27+
<ul>
28+
<li> <code>einops</code> documentation - <a href="https://einops.rocks/1-einops-basics/#welcome-to-einops-land">Einops basics</a> </li>
29+
</ul>
30+
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"title": "Einops pattern should be valid",
3+
"type": "BUG",
4+
"status": "ready",
5+
"remediation": {
6+
"func": "Constant\/Issue",
7+
"constantCost": "5min"
8+
},
9+
"tags": [],
10+
"defaultSeverity": "Major",
11+
"ruleSpecification": "RSPEC-6984",
12+
"sqKey": "S6984",
13+
"scope": "All",
14+
"quickfix": "infeasible",
15+
"code": {
16+
"impacts": {
17+
"RELIABILITY": "HIGH"
18+
},
19+
"attribute": "LOGICAL"
20+
}
21+
}

python-checks/src/main/resources/org/sonar/l10n/py/rules/python/Sonar_way_profile.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@
248248
"S6974",
249249
"S6979",
250250
"S6983",
251+
"S6984",
251252
"S6985"
252253
]
253254
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.sonar.python.checks.utils.PythonCheckVerifier;
24+
25+
class EinopsSyntaxCheckTest {
26+
27+
@Test
28+
void test() {
29+
PythonCheckVerifier.verify("src/test/resources/checks/einopsSyntaxCheck.py", new EinopsSyntaxCheck()); }
30+
}
31+
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import test
2+
import torch
3+
import einops
4+
from einops import reduce, repeat, rearrange
5+
# Fixture for the einops pattern check (rule S6984); presumably the file loaded by EinopsSyntaxCheckTest - confirm the path.
6+
img = torch.randn(32, 32, 3)
7+
imgs = torch.randn(10, 32, 32, 3)
8+
# Well-formed patterns: no issue expected.
9+
rearrange(img, 'h w c -> c h w')
10+
rearrange(imgs, 'b h w c -> b c h w')
11+
rearrange(imgs, 'b h w c -> (b h) w c')
12+
# Patterns without two non-empty sides around "->": the issue is expected on the callee.
13+
a = rearrange(imgs,"") # Noncompliant {{Provide a valid einops pattern.}}
14+
# ^^^^^^^^^
15+
a = rearrange(imgs,"b h") # Noncompliant
16+
# ^^^^^^^^^
17+
a = rearrange(imgs,"b h ->") # Noncompliant
18+
# ^^^^^^^^^
19+
a = rearrange(imgs,"-> h w") # Noncompliant
20+
# ^^^^^^^^^
21+
a = rearrange(imgs,"->") # Noncompliant
22+
# ^^^^^^^^^
23+
# Ellipsis inside parentheses: expected only on the left side and only for einops callees (test.rearrange is ignored).
24+
test.rearrange(imgs, '(...) -> (... h) w c ')
25+
unstacked2 = rearrange(imgs, 'h w c -> (... h) w c')
26+
unstacked2 = rearrange(imgs, '(... h) w c -> ... h w c') # Noncompliant {{Fix the syntax of this einops operation: Ellipsis inside parenthesis on the left side is not allowed.}}
27+
#^^^^^^^^^^^^^^^^^^^^^^^^^^
28+
unstacked2 = einops.rearrange(imgs, '(... h) w c -> ... h w c') # Noncompliant
29+
#^^^^^^^^^^^^^^^^^^^^^^^^^^
30+
unstacked2 = einops.rearrange(pattern='(... h) w c -> ... h w c', tensor=imgs) # Noncompliant
31+
#^^^^^^^^^^^^^^^^^^^^^^^^^^
32+
unstacked2 = rearrange(imgs, '(... h) w c -> (... h) w c') # Noncompliant
33+
#^^^^^^^^^^^^^^^^^^^^^^^^^^^^
34+
35+
unstacked2 = rearrange(imgs, '(… h) w c -> (... h) w c') # Noncompliant
36+
#^^^^^^^^^^^^^^^^^^^^^^^^^^
37+
# Nested or unbalanced parentheses on either side.
38+
repeat(imgs, "(h) (( w c) -> (h w c)") # Noncompliant {{Fix the syntax of this einops operation: nested parenthesis are not allowed.}}
39+
repeat(imgs, "(h w c -> (h w c)") # Noncompliant {{Fix the syntax of this einops operation: parenthesis are unbalanced.}}
40+
repeat(imgs, "h w c -> h w c))") # Noncompliant
41+
repeat(imgs, "h w c -> h w c(") # Noncompliant
42+
repeat(imgs, "h w c) -> h w c") # Noncompliant
43+
repeat(imgs, "h w c -> (h w c(") # Noncompliant
44+
rearrange(imgs, ")h w c -> h w c") # Noncompliant
45+
reduce(imgs, "h w c -> (h w c(") # FN should be fixed with SONARPY-2137
46+
47+
# Extra positional arguments, and keyword arguments naming an axis of the pattern: no issue expected.
48+
reduce(imgs, 'b c -> b c', 'max')
49+
rearrange(imgs, "h w c -> h w c", 1) # Not a correct parameter but still we should not raise.
50+
unstacked = rearrange(imgs, '(b h) w c -> b h w c', b=10)
51+
rearrange(imgs, 'b c h2 -> b c w2', h2=2, w2=2)
52+
# Numeric keyword arguments whose name appears on neither side of the pattern.
53+
rearrange(imgs, 'b c -> b c', h2=2, w2=2) # Noncompliant {{Fix the syntax of this einops operation: the parameters h2, w2 do not appear in the pattern.}}
54+
rearrange(imgs, "(b h) w c -> b h w c ", b1=1) # Noncompliant {{Fix the syntax of this einops operation: the parameter b1 does not appear in the pattern.}}
55+
rearrange(imgs, "(b h) w c -> b h w c b1", b1=1)
56+
reduce(imgs, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, z=2) # FN should be fixed with SONARPY-2137
57+
reduce(imgs, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2) # FN should be fixed with SONARPY-2137

0 commit comments

Comments
 (0)