Skip to content

Commit 0367141

Browse files
authored
Extend constraints to support ternary expressions (#2836)
1 parent fe09b02 commit 0367141

File tree

4 files changed

+191
-4
lines changed

4 files changed

+191
-4
lines changed

ChangeLog

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ What's New in astroid 4.0.0?
77
============================
88
Release date: TBA
99

10+
* Support constraints from ternary expressions in inference.
11+
12+
Closes pylint-dev/pylint#9729
13+
1014
* Handle deprecated `bool(NotImplemented)` cast in const nodes.
1115

1216
* Add support for boolean truthiness constraints (`x`, `not x`) in inference.

astroid/constraint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def satisfied_by(self, inferred: InferenceResult) -> bool:
127127

128128
def get_constraints(
129129
expr: _NameNodes, frame: nodes.LocalsDictNodeNG
130-
) -> dict[nodes.If, set[Constraint]]:
130+
) -> dict[nodes.If | nodes.IfExp, set[Constraint]]:
131131
"""Returns the constraints for the given expression.
132132
133133
The returned dictionary maps the node where the constraint was generated to the
@@ -137,10 +137,10 @@ def get_constraints(
137137
Currently this only supports constraints generated from if conditions.
138138
"""
139139
current_node: nodes.NodeNG | None = expr
140-
constraints_mapping: dict[nodes.If, set[Constraint]] = {}
140+
constraints_mapping: dict[nodes.If | nodes.IfExp, set[Constraint]] = {}
141141
while current_node is not None and current_node is not frame:
142142
parent = current_node.parent
143-
if isinstance(parent, nodes.If):
143+
if isinstance(parent, (nodes.If, nodes.IfExp)):
144144
branch, _ = parent.locate_child(current_node)
145145
constraints: set[Constraint] | None = None
146146
if branch == "body":

astroid/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def __init__(
8080
self.extra_context: dict[SuccessfulInferenceResult, InferenceContext] = {}
8181
"""Context that needs to be passed down through call stacks for call arguments."""
8282

83-
self.constraints: dict[str, dict[nodes.If, set[constraint.Constraint]]] = {}
83+
self.constraints: dict[
84+
str, dict[nodes.If | nodes.IfExp, set[constraint.Constraint]]
85+
] = {}
8486
"""The constraints on nodes."""
8587

8688
@property

tests/test_constraint.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,184 @@ def method(self):
592592

593593
assert isinstance(inferred[1], nodes.Const)
594594
assert inferred[1].value == fail_val
595+
596+
597+
@common_params(node="x")
598+
def test_if_exp_body(
599+
condition: str, satisfy_val: int | None, fail_val: int | None
600+
) -> None:
601+
"""Test constraint for a variable that is used in an if exp body."""
602+
node1, node2 = builder.extract_node(
603+
f"""
604+
def f1(x = {fail_val}):
605+
return (
606+
x if {condition} else None #@
607+
)
608+
609+
def f2(x = {satisfy_val}):
610+
return (
611+
x if {condition} else None #@
612+
)
613+
"""
614+
)
615+
616+
inferred = node1.body.inferred()
617+
assert len(inferred) == 1
618+
assert inferred[0] is Uninferable
619+
620+
inferred = node2.body.inferred()
621+
assert len(inferred) == 2
622+
assert isinstance(inferred[0], nodes.Const)
623+
assert inferred[0].value == satisfy_val
624+
assert inferred[1] is Uninferable
625+
626+
627+
@common_params(node="x")
628+
def test_if_exp_else(
629+
condition: str, satisfy_val: int | None, fail_val: int | None
630+
) -> None:
631+
"""Test constraint for a variable that is used in an if exp else block."""
632+
node1, node2 = builder.extract_node(
633+
f"""
634+
def f1(x = {satisfy_val}):
635+
return (
636+
None if {condition} else x #@
637+
)
638+
639+
def f2(x = {fail_val}):
640+
return (
641+
None if {condition} else x #@
642+
)
643+
"""
644+
)
645+
646+
inferred = node1.orelse.inferred()
647+
assert len(inferred) == 1
648+
assert inferred[0] is Uninferable
649+
650+
inferred = node2.orelse.inferred()
651+
assert len(inferred) == 2
652+
assert isinstance(inferred[0], nodes.Const)
653+
assert inferred[0].value == fail_val
654+
assert inferred[1] is Uninferable
655+
656+
657+
@common_params(node="x")
658+
def test_outside_if_exp(
659+
condition: str, satisfy_val: int | None, fail_val: int | None
660+
) -> None:
661+
"""Test that constraint in an if exp condition doesn't apply outside of the if exp."""
662+
nodes_ = builder.extract_node(
663+
f"""
664+
def f1(x = {fail_val}):
665+
x if {condition} else None
666+
return (
667+
x #@
668+
)
669+
670+
def f2(x = {satisfy_val}):
671+
None if {condition} else x
672+
return (
673+
x #@
674+
)
675+
"""
676+
)
677+
for node, val in zip(nodes_, (fail_val, satisfy_val)):
678+
inferred = node.inferred()
679+
assert len(inferred) == 2
680+
assert isinstance(inferred[0], nodes.Const)
681+
assert inferred[0].value == val
682+
assert inferred[1] is Uninferable
683+
684+
685+
@common_params(node="x")
686+
def test_nested_if_exp(
687+
condition: str, satisfy_val: int | None, fail_val: int | None
688+
) -> None:
689+
"""Test that constraint in an if exp condition applies within inner if exp."""
690+
node1, node2 = builder.extract_node(
691+
f"""
692+
def f1(y, x = {fail_val}):
693+
return (
694+
(x if y else None) if {condition} else None #@
695+
)
696+
697+
def f2(y, x = {satisfy_val}):
698+
return (
699+
(x if y else None) if {condition} else None #@
700+
)
701+
"""
702+
)
703+
704+
inferred = node1.body.body.inferred()
705+
assert len(inferred) == 1
706+
assert inferred[0] is Uninferable
707+
708+
inferred = node2.body.body.inferred()
709+
assert len(inferred) == 2
710+
assert isinstance(inferred[0], nodes.Const)
711+
assert inferred[0].value == satisfy_val
712+
assert inferred[1] is Uninferable
713+
714+
715+
@common_params(node="self.x")
716+
def test_if_exp_instance_attr(
717+
condition: str, satisfy_val: int | None, fail_val: int | None
718+
) -> None:
719+
"""Test constraint for an instance attribute in an if exp."""
720+
node1, node2 = builder.extract_node(
721+
f"""
722+
class A1:
723+
def __init__(self, x = {fail_val}):
724+
self.x = x
725+
726+
def method(self):
727+
return (
728+
self.x if {condition} else None #@
729+
)
730+
731+
class A2:
732+
def __init__(self, x = {satisfy_val}):
733+
self.x = x
734+
735+
def method(self):
736+
return (
737+
self.x if {condition} else None #@
738+
)
739+
"""
740+
)
741+
742+
inferred = node1.body.inferred()
743+
assert len(inferred) == 1
744+
assert inferred[0] is Uninferable
745+
746+
inferred = node2.body.inferred()
747+
assert len(inferred) == 2
748+
assert isinstance(inferred[0], nodes.Const)
749+
assert inferred[0].value == satisfy_val
750+
assert inferred[1].value is Uninferable
751+
752+
753+
@common_params(node="self.x")
754+
def test_if_exp_instance_attr_varname_collision(
755+
condition: str, satisfy_val: int | None, fail_val: int | None
756+
) -> None:
757+
"""Test that constraint in an if exp condition doesn't apply to a variable with the same name."""
758+
node = builder.extract_node(
759+
f"""
760+
class A:
761+
def __init__(self, x = {fail_val}):
762+
self.x = x
763+
764+
def method(self, x = {fail_val}):
765+
return (
766+
x if {condition} else None #@
767+
)
768+
"""
769+
)
770+
771+
inferred = node.body.inferred()
772+
assert len(inferred) == 2
773+
assert isinstance(inferred[0], nodes.Const)
774+
assert inferred[0].value == fail_val
775+
assert inferred[1].value is Uninferable

0 commit comments

Comments
 (0)