Skip to content

Commit 3980042

Browse files
SONARPY-1815 Enable AST-based type inference for functions/module containing try/catch blocks (#1792)
1 parent 1f1af46 commit 3980042

File tree

4 files changed

+312
-5
lines changed

4 files changed

+312
-5
lines changed

python-frontend/src/main/java/org/sonar/python/semantic/v2/TypeInferenceV2.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,7 @@ private static void inferTypesAndMemberAccessSymbols(Tree scopeTree,
120120
statements.accept(tryStatementVisitor);
121121
if (tryStatementVisitor.hasTryStatement()) {
122122
// The CFG doesn't precisely model try-except statements, hence we fall back to AST-based type inference
123-
// TODO: Check if still relevant
124-
/* visitor.processPropagations(getTrackedVars(declaredVariables, assignedNames));
125-
statements.accept(new TypeInference.NameVisitor());*/
123+
propagationVisitor.processPropagations(getTrackedVars(declaredVariables, assignedNames));
126124
} else {
127125
ControlFlowGraph cfg = controlFlowGraphSupplier.get();
128126
if (cfg == null) {

python-frontend/src/main/java/org/sonar/python/semantic/v2/types/Assignment.java

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,84 @@
1919
*/
2020
package org.sonar.python.semantic.v2.types;
2121

22+
import java.util.ArrayDeque;
23+
import java.util.Deque;
24+
import java.util.HashSet;
25+
import java.util.Map;
26+
import java.util.Set;
2227
import org.sonar.plugins.python.api.tree.Expression;
2328
import org.sonar.plugins.python.api.tree.Name;
2429
import org.sonar.python.semantic.v2.SymbolV2;
30+
import org.sonar.python.semantic.v2.UsageV2;
31+
import org.sonar.python.tree.NameImpl;
32+
import org.sonar.python.types.HasTypeDependencies;
33+
import org.sonar.python.types.v2.PythonType;
34+
import org.sonar.python.types.v2.UnionType;
2535

26-
public record Assignment(SymbolV2 lhsSymbol, Name lhsName, Expression rhs) {
36+
public class Assignment {
37+
38+
final SymbolV2 lhsSymbol;
39+
Name lhsName;
40+
Expression rhs;
41+
Set<SymbolV2> variableDependencies = new HashSet<>();
42+
Set<Assignment> dependents = new HashSet<>();
43+
Map<SymbolV2, Set<Assignment>> assignmentsByLhs;
44+
45+
public Assignment(SymbolV2 lhsSymbol, Name lhsName, Expression rhs, Map<SymbolV2, Set<Assignment>> assignmentsByLhs) {
46+
this.lhsSymbol = lhsSymbol;
47+
this.lhsName = lhsName;
48+
this.rhs = rhs;
49+
this.assignmentsByLhs = assignmentsByLhs;
50+
}
51+
52+
void computeDependencies(Expression expression, Set<SymbolV2> trackedVars) {
53+
Deque<Expression> workList = new ArrayDeque<>();
54+
workList.push(expression);
55+
while (!workList.isEmpty()) {
56+
Expression e = workList.pop();
57+
if (e instanceof Name name) {
58+
SymbolV2 symbol = name.symbolV2();
59+
if (symbol != null && trackedVars.contains(symbol)) {
60+
variableDependencies.add(symbol);
61+
assignmentsByLhs.get(symbol).forEach(a -> a.dependents.add(this));
62+
}
63+
} else if (e instanceof HasTypeDependencies hasTypeDependencies) {
64+
workList.addAll(hasTypeDependencies.typeDependencies());
65+
}
66+
}
67+
}
68+
69+
boolean areDependenciesReady(Set<SymbolV2> initializedVars) {
70+
return initializedVars.containsAll(variableDependencies);
71+
}
72+
73+
/** @return true if the propagation effectively changed the inferred type of lhs */
74+
public boolean propagate(Set<SymbolV2> initializedVars) {
75+
PythonType rhsType = rhs.typeV2();
76+
if (initializedVars.add(lhsSymbol)) {
77+
lhsSymbol.usages().stream().map(UsageV2::tree).filter(NameImpl.class::isInstance).map(NameImpl.class::cast).forEach(n -> n.typeV2(rhsType));
78+
return true;
79+
} else {
80+
PythonType currentType = lhsName.typeV2();
81+
PythonType newType = UnionType.or(rhsType, currentType);
82+
lhsSymbol.usages().stream().map(UsageV2::tree).filter(NameImpl.class::isInstance).map(NameImpl.class::cast).forEach(n -> n.typeV2(newType));
83+
return !newType.equals(currentType);
84+
}
85+
}
86+
87+
public Name lhsName() {
88+
return lhsName;
89+
}
90+
91+
public SymbolV2 lhsSymbol() {
92+
return lhsSymbol;
93+
}
94+
95+
public Expression rhs() {
96+
return rhs;
97+
}
98+
99+
public Set<Assignment> dependents() {
100+
return dependents;
101+
}
27102
}

python-frontend/src/main/java/org/sonar/python/semantic/v2/types/PropagationVisitor.java

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import java.util.HashMap;
2323
import java.util.HashSet;
24+
import java.util.Iterator;
2425
import java.util.List;
2526
import java.util.Map;
2627
import java.util.Set;
@@ -83,9 +84,39 @@ public void visitAnnotatedAssignment(AnnotatedAssignment annotatedAssignment){
8384
// Records an assignment when its left-hand side is a Name bound to a symbol: the Assignment is indexed
// both by its statement and by the assigned symbol. Any other left-hand side is ignored.
// The two constructor calls below are the removed (86-) and added (87+) versions of the same diff line.
private void processAssignment(Statement assignmentStatement, Expression lhsExpression, Expression rhsExpression){
8485
if (lhsExpression instanceof Name lhs && lhs.symbolV2() != null) {
8586
var symbol = lhs.symbolV2();
86-
Assignment assignment = new Assignment(symbol, lhs, rhsExpression);
87+
Assignment assignment = new Assignment(symbol, lhs, rhsExpression, assignmentsByLhs);
8788
assignmentsByAssignmentStatement.put(assignmentStatement, assignment);
8889
assignmentsByLhs.computeIfAbsent(symbol, s -> new HashSet<>()).add(assignment);
8990
}
9091
}
92+
93+
public void processPropagations(Set<SymbolV2> trackedVars) {
94+
Set<Assignment> propagations = new HashSet<>();
95+
Set<SymbolV2> initializedVars = new HashSet<>();
96+
97+
assignmentsByLhs.forEach((lhs, as) -> {
98+
if (trackedVars.contains(lhs)) {
99+
as.forEach(a -> a.computeDependencies(a.rhs(), trackedVars));
100+
propagations.addAll(as);
101+
}
102+
});
103+
104+
applyPropagations(propagations, initializedVars, true);
105+
applyPropagations(propagations, initializedVars, false);
106+
}
107+
108+
private static void applyPropagations(Set<Assignment> propagations, Set<SymbolV2> initializedVars, boolean checkDependenciesReadiness) {
109+
Set<Assignment> workSet = new HashSet<>(propagations);
110+
while (!workSet.isEmpty()) {
111+
Iterator<Assignment> iterator = workSet.iterator();
112+
Assignment propagation = iterator.next();
113+
iterator.remove();
114+
if (!checkDependenciesReadiness || propagation.areDependenciesReady(initializedVars)) {
115+
boolean learnt = propagation.propagate(initializedVars);
116+
if (learnt) {
117+
workSet.addAll(propagation.dependents());
118+
}
119+
}
120+
}
121+
}
91122
}

python-frontend/src/test/java/org/sonar/python/semantic/v2/TypeInferenceV2Test.java

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.HashSet;
2525
import java.util.List;
2626
import java.util.Map;
27+
import java.util.Optional;
2728
import java.util.Set;
2829
import org.assertj.core.api.Assertions;
2930
import org.junit.jupiter.api.Disabled;
@@ -39,6 +40,7 @@
3940
import org.sonar.plugins.python.api.tree.ImportFrom;
4041
import org.sonar.plugins.python.api.tree.ImportName;
4142
import org.sonar.plugins.python.api.tree.Name;
43+
import org.sonar.plugins.python.api.tree.RegularArgument;
4244
import org.sonar.plugins.python.api.tree.Statement;
4345
import org.sonar.plugins.python.api.tree.StatementList;
4446
import org.sonar.plugins.python.api.tree.Tree;
@@ -626,6 +628,207 @@ def foo():
626628
""").typeV2().unwrappedType()).isEqualTo(INT_TYPE);
627629
}
628630

631+
/**
 * Module containing a try/except: asserts that the argument of each of the three {@code type(x)} calls
 * is typed as a union whose candidates unwrap to int and str.
 */
@Test
632+
void flow_insensitive_when_try_except() {
633+
FileInput fileInput = inferTypes("""
634+
try:
635+
if p:
636+
x = 42
637+
type(x)
638+
else:
639+
x = "foo"
640+
type(x)
641+
except:
642+
type(x)
643+
""");
644+

645+
List<CallExpression> calls = PythonTestUtils.getAllDescendant(fileInput, tree -> tree.is(Tree.Kind.CALL_EXPR));
646+
RegularArgument firstX = (RegularArgument) calls.get(0).arguments().get(0);
647+
RegularArgument secondX = (RegularArgument) calls.get(1).arguments().get(0);
648+
RegularArgument thirdX = (RegularArgument) calls.get(2).arguments().get(0);
649+
assertThat(((UnionType) firstX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
650+
assertThat(((UnionType) secondX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
651+
assertThat(((UnionType) thirdX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
652+
}
653+
654+
/**
 * Snippet with a try/except inside function {@code f}, a function {@code g} and trailing module statements.
 * Asserts that the three reads of {@code x} are typed as a union of int and str, while the reads of
 * {@code y} and of {@code z} are typed int, then str, then the union of both.
 */
@Test
655+
void nested_try_except() {
656+
FileInput fileInput = inferTypes("""
657+
def f(p):
658+
try:
659+
if p:
660+
x = 42
661+
type(x)
662+
else:
663+
x = "foo"
664+
type(x)
665+
except:
666+
type(x)
667+
def g(p):
668+
if p:
669+
y = 42
670+
type(y)
671+
else:
672+
y = "hello"
673+
type(y)
674+
type(y)
675+
if cond:
676+
z = 42
677+
type(z)
678+
else:
679+
z = "hello"
680+
type(z)
681+
type(z)
682+
""");
683+
List<CallExpression> calls = PythonTestUtils.getAllDescendant(fileInput, tree -> tree.is(Tree.Kind.CALL_EXPR));
684+
RegularArgument firstX = (RegularArgument) calls.get(0).arguments().get(0);
685+
RegularArgument secondX = (RegularArgument) calls.get(1).arguments().get(0);
686+
RegularArgument thirdX = (RegularArgument) calls.get(2).arguments().get(0);
687+
assertThat(((UnionType) firstX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
688+
assertThat(((UnionType) secondX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
689+
assertThat(((UnionType) thirdX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
690+

691+
RegularArgument firstY = (RegularArgument) calls.get(3).arguments().get(0);
692+
RegularArgument secondY = (RegularArgument) calls.get(4).arguments().get(0);
693+
RegularArgument thirdY = (RegularArgument) calls.get(5).arguments().get(0);
694+
assertThat(firstY.expression().typeV2().unwrappedType()).isEqualTo(INT_TYPE);
695+
assertThat(secondY.expression().typeV2().unwrappedType()).isEqualTo(STR_TYPE);
696+
assertThat(((UnionType) thirdY.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
697+

698+
RegularArgument firstZ = (RegularArgument) calls.get(6).arguments().get(0);
699+
RegularArgument secondZ = (RegularArgument) calls.get(7).arguments().get(0);
700+
RegularArgument thirdZ = (RegularArgument) calls.get(8).arguments().get(0);
701+
assertThat(firstZ.expression().typeV2().unwrappedType()).isEqualTo(INT_TYPE);
702+
assertThat(secondZ.expression().typeV2().unwrappedType()).isEqualTo(STR_TYPE);
703+
assertThat(((UnionType) thirdZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
704+
}
705+
706+
/**
 * Snippet starting with a try/except, followed by a function {@code g} and further statements on {@code z}.
 * Asserts that all reads of {@code x} and all reads of {@code z} are typed as a union of int and str,
 * while the reads of {@code y} are typed int, then str, then the union of both.
 */
@Test
707+
void nested_try_except_2() {
708+
FileInput fileInput = inferTypes("""
709+
try:
710+
if p:
711+
x = 42
712+
type(x)
713+
else:
714+
x = "foo"
715+
type(x)
716+
except:
717+
type(x)
718+
def g(p):
719+
if p:
720+
y = 42
721+
type(y)
722+
else:
723+
y = "hello"
724+
type(y)
725+
type(y)
726+
if cond:
727+
z = 42
728+
type(z)
729+
else:
730+
z = "hello"
731+
type(z)
732+
type(z)
733+
""");
734+
List<CallExpression> calls = PythonTestUtils.getAllDescendant(fileInput, tree -> tree.is(Tree.Kind.CALL_EXPR));
735+
RegularArgument firstX = (RegularArgument) calls.get(0).arguments().get(0);
736+
RegularArgument secondX = (RegularArgument) calls.get(1).arguments().get(0);
737+
RegularArgument thirdX = (RegularArgument) calls.get(2).arguments().get(0);
738+
assertThat(((UnionType) firstX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
739+
assertThat(((UnionType) secondX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
740+
assertThat(((UnionType) thirdX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
741+

742+
RegularArgument firstY = (RegularArgument) calls.get(3).arguments().get(0);
743+
RegularArgument secondY = (RegularArgument) calls.get(4).arguments().get(0);
744+
RegularArgument thirdY = (RegularArgument) calls.get(5).arguments().get(0);
745+
assertThat(firstY.expression().typeV2().unwrappedType()).isEqualTo(INT_TYPE);
746+
assertThat(secondY.expression().typeV2().unwrappedType()).isEqualTo(STR_TYPE);
747+
assertThat(((UnionType) thirdY.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
748+

749+
RegularArgument firstZ = (RegularArgument) calls.get(6).arguments().get(0);
750+
RegularArgument secondZ = (RegularArgument) calls.get(7).arguments().get(0);
751+
RegularArgument thirdZ = (RegularArgument) calls.get(8).arguments().get(0);
752+
assertThat(((UnionType) firstZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
753+
assertThat(((UnionType) secondZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
754+
assertThat(((UnionType) thirdZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
755+
}
756+
757+
/**
 * Chained assignments ({@code y = x}, {@code z = y}) in both the try and the except parts, where {@code x}
 * is assigned an int and a str. Asserts that all nine {@code type(...)} arguments, for {@code x},
 * {@code y} and {@code z}, are typed as a union of int and str.
 */
@Test
758+
void try_except_with_dependents() {
759+
FileInput fileInput = inferTypes("""
760+
try:
761+
x = 42
762+
y = x
763+
z = y
764+
type(x)
765+
type(y)
766+
type(z)
767+
except:
768+
x = "hello"
769+
y = x
770+
z = y
771+
type(x)
772+
type(y)
773+
type(z)
774+
type(x)
775+
type(y)
776+
type(z)
777+
""");
778+

779+
List<CallExpression> calls = PythonTestUtils.getAllDescendant(fileInput, tree -> tree.is(Tree.Kind.CALL_EXPR));
780+
RegularArgument firstX = (RegularArgument) calls.get(0).arguments().get(0);
781+
RegularArgument firstY = (RegularArgument) calls.get(1).arguments().get(0);
782+
RegularArgument firstZ = (RegularArgument) calls.get(2).arguments().get(0);
783+
assertThat(((UnionType) firstX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
784+
assertThat(((UnionType) firstY.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
785+
assertThat(((UnionType) firstZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
786+

787+
RegularArgument secondX = (RegularArgument) calls.get(3).arguments().get(0);
788+
RegularArgument secondY = (RegularArgument) calls.get(4).arguments().get(0);
789+
RegularArgument secondZ = (RegularArgument) calls.get(5).arguments().get(0);
790+
assertThat(((UnionType) secondX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
791+
assertThat(((UnionType) secondY.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
792+
assertThat(((UnionType) secondZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
793+

794+
RegularArgument thirdX = (RegularArgument) calls.get(6).arguments().get(0);
795+
RegularArgument thirdY = (RegularArgument) calls.get(7).arguments().get(0);
796+
RegularArgument thirdZ = (RegularArgument) calls.get(8).arguments().get(0);
797+
assertThat(((UnionType) thirdX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
798+
assertThat(((UnionType) thirdY.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
799+
assertThat(((UnionType) thirdZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
800+
}
801+
802+
/**
 * {@code my_list} is assigned a list of ints and a list of strs around a try/except.
 * Asserts that the first {@code type(my_list)} argument is a union of two list candidates whose
 * attributes unwrap to int and str, and that the two other reads have a type equal to that union.
 */
@Test
803+
void try_except_list_attributes() {
804+
FileInput fileInput = inferTypes("""
805+
try:
806+
my_list = [1, 2, 3]
807+
type(my_list)
808+
except:
809+
my_list = ["a", "b", "c"]
810+
type(my_list)
811+
type(my_list)
812+
""");
813+

814+
List<CallExpression> calls = PythonTestUtils.getAllDescendant(fileInput, tree -> tree.is(Tree.Kind.CALL_EXPR));
815+
RegularArgument list1 = (RegularArgument) calls.get(0).arguments().get(0);
816+
RegularArgument list2 = (RegularArgument) calls.get(1).arguments().get(0);
817+
RegularArgument list3 = (RegularArgument) calls.get(2).arguments().get(0);
818+

819+
UnionType listType = (UnionType) list1.expression().typeV2();
820+
assertThat(listType.candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(LIST_TYPE, LIST_TYPE);
821+
assertThat(listType.candidates())
822+
.map(ObjectType.class::cast)
823+
.flatExtracting(ObjectType::attributes)
824+
.extracting(PythonType::unwrappedType)
825+
.containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
826+

827+
assertThat(list2.expression().typeV2()).isEqualTo(listType);
828+
assertThat(list3.expression().typeV2()).isEqualTo(listType);
829+

830+
}
831+
629832
// Convenience overload: delegates to the two-argument inferTypes with a fresh, empty HashMap.
private static FileInput inferTypes(String lines) {
630833
return inferTypes(lines, new HashMap<>());
631834
}

0 commit comments

Comments
 (0)