Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,18 +411,15 @@ def _infer_constraints(
# When the template is a union, we are okay with leaving some
# type variables indeterminate. This helps with some special
# cases, though this isn't very principled.
result = any_constraints(
if has_recursive_types(template) and not has_recursive_types(actual):
return handle_recursive_union(template, actual, direction)
return any_constraints(
[
infer_constraints_if_possible(t_item, actual, direction)
for t_item in template.items
],
eager=isinstance(actual, AnyType),
)
if result:
return result
elif has_recursive_types(template) and not has_recursive_types(actual):
return handle_recursive_union(template, actual, direction)
return []

# Remaining cases are handled by ConstraintBuilderVisitor.
return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op))
Expand Down Expand Up @@ -535,13 +532,12 @@ def any_constraints(options: list[list[Constraint] | None], *, eager: bool) -> l
# Multiple sets of constraints that are all the same. Just pick any one of them.
return valid_options[0]

if all(is_similar_constraints(valid_options[0], c) for c in valid_options[1:]):
all_similar = all(is_similar_constraints(valid_options[0], c) for c in valid_options[1:])
if all_similar:
# All options have same structure. In this case we can merge-in trivial
# options (i.e. those that only have Any) and try again.
# TODO: More generally, if a given (variable, direction) pair appears in
# every option, combine the bounds with meet/join always, not just for Any.
trivial_options = select_trivial(valid_options)
if trivial_options and len(trivial_options) < len(valid_options):
if 0 < len(trivial_options) < len(valid_options):
merged_options = []
for option in valid_options:
if option in trivial_options:
Expand All @@ -563,6 +559,31 @@ def any_constraints(options: list[list[Constraint] | None], *, eager: bool) -> l
if filtered_options != options:
return any_constraints(filtered_options, eager=eager)

if (
eager
and all_similar
and not any(isinstance(c.target, ErasedType) for group in valid_options for c in group)
):
# Now we know all constraints might be satisfiable and have similar structure.
# Solver will apply meets and joins as necessary, but Any should be forced into
# union to survive during meet.
# If any targets are erased, fall back to empty, otherwise they will be discarded
# by solver, causing false early matches.
cmap: dict[TypeVarId, list[Constraint]] = {}
for option in valid_options:
for c in option:
cmap.setdefault(c.type_var, []).append(c)
out: list[Constraint] = []
for group in cmap.values():
if any(isinstance(get_proper_type(c.target), AnyType) for c in group):
group = [
merge_with_any(c)
for c in group
if not isinstance(get_proper_type(c.target), AnyType)
]
out.extend(dict.fromkeys(group))
return out

# Otherwise, there are either no valid options or multiple, inconsistent valid
# options. Give up and deduce nothing.
return []
Expand Down
30 changes: 30 additions & 0 deletions test-data/unit/check-inference-context.test
Original file line number Diff line number Diff line change
Expand Up @@ -1530,3 +1530,33 @@ def check3(use: bool, val: str) -> "str | Literal[False]":
def check4(use: bool, val: str) -> "str | bool":
return use and identity(val)
[builtins fixtures/tuple.pyi]

[case testDictOrLiteralInContext]
from typing import Union, Optional, Any

P = dict[str, Union[Optional[str], dict[str, Optional[str]]]]

def f(x: P) -> None:
pass

def g(x: Union[dict[str, Any], None], s: Union[str, None]) -> None:
f(x or {'x': s})
[builtins fixtures/dict.pyi]

[case testInferConstrainedTypeVarInUnion]
from typing import Generic, TypeVar, Union

_S_co = TypeVar("_S_co", str, int, covariant=True)
_S = TypeVar("_S", str, int)

class HasFoo(Generic[_S_co]):
def foo(self) -> _S_co: ...

def walk(path: Union[_S, HasFoo[_S]]) -> None:
...

class Path(HasFoo[str]):
def foo(self) -> str: ...

walk(Path())
[builtins fixtures/tuple.pyi]
Loading