Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,13 +451,13 @@ def __init__(
self.type_stack = []
# Are the namespaces of classes being processed complete?
self.incomplete_type_stack: list[bool] = []
self.tvar_scope = TypeVarLikeScope()
self.function_stack = []
self.block_depth = [0]
self.loop_depth = [0]
self.errors = errors
self.modules = modules
self.msg = MessageBuilder(errors, modules)
self.tvar_scope = TypeVarLikeScope(msg=self.msg)
self.missing_modules = missing_modules
self.missing_names = [set()]
# These namespaces are still in process of being populated. If we encounter a
Expand Down Expand Up @@ -859,7 +859,7 @@ def file_context(
self._is_stub_file = file_node.path.lower().endswith(".pyi")
self._is_typeshed_stub_file = file_node.is_typeshed_file(options)
self.globals = file_node.names
self.tvar_scope = TypeVarLikeScope()
self.tvar_scope = TypeVarLikeScope(msg=self.msg)

self.named_tuple_analyzer = NamedTupleAnalyzer(options, self, self.msg)
self.typed_dict_analyzer = TypedDictAnalyzer(options, self, self.msg)
Expand Down Expand Up @@ -2404,7 +2404,7 @@ def tvar_defs_from_tvars(
self.fail(
message_registry.TYPE_VAR_REDECLARED_IN_NESTED_CLASS.format(name), context
)
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
tvar_def = self.tvar_scope.bind_new(name, tvar_expr, context)
if last_tvar_name_with_default is not None and not tvar_def.has_default():
self.msg.tvar_without_default_type(
tvar_def.name, last_tvar_name_with_default, context
Expand All @@ -2422,19 +2422,18 @@ def get_and_bind_all_tvars(self, type_exprs: list[Expression]) -> list[TypeVarLi
a simplified version of the logic we use for ClassDef bases. We duplicate
some amount of code, because it is hard to refactor common pieces.
"""
tvars = []
tvars: dict[str, tuple[TypeVarLikeExpr, Expression]] = {}
for base_expr in type_exprs:
try:
base = self.expr_to_unanalyzed_type(base_expr)
except TypeTranslationError:
# This error will be caught later.
continue
base_tvars = self.find_type_var_likes(base)
tvars.extend(base_tvars)
tvars = remove_dups(tvars) # Variables are defined in order of textual appearance.
for name, expr in self.find_type_var_likes(base):
tvars.setdefault(name, (expr, base_expr))
tvar_defs = []
for name, tvar_expr in tvars:
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
for name, (tvar_expr, context) in tvars.items():
tvar_def = self.tvar_scope.bind_new(name, tvar_expr, context)
tvar_defs.append(tvar_def)
return tvar_defs

Expand Down Expand Up @@ -7442,7 +7441,7 @@ def analyze_type_expr(self, expr: Expression) -> None:
# them semantically analyzed, however, if they need to treat it as an expression
# and not a type. (Which is to say, mypyc needs to do this.) Do the analysis
# in a fresh tvar scope in order to suppress any errors about using type variables.
with self.tvar_scope_frame(TypeVarLikeScope()), self.allow_unbound_tvars_set():
with self.tvar_scope_frame(TypeVarLikeScope(msg=self.msg)), self.allow_unbound_tvars_set():
expr.accept(self)

def type_analyzer(
Expand Down
87 changes: 64 additions & 23 deletions mypy/tvar_scope.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,70 @@
from __future__ import annotations

from mypy.messages import MessageBuilder
from mypy.nodes import (
Context,
ParamSpecExpr,
SymbolTableNode,
TypeVarExpr,
TypeVarLikeExpr,
TypeVarTupleExpr,
)
from mypy.types import (
AnyType,
ParamSpecFlavor,
ParamSpecType,
TrivialSyntheticTypeTranslator,
Type,
TypeAliasType,
TypeOfAny,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
)
from mypy.typetraverser import TypeTraverserVisitor


class TypeVarLikeNamespaceSetter(TypeTraverserVisitor):
class TypeVarLikeDefaultFixer(TrivialSyntheticTypeTranslator):
"""Set namespace for all TypeVarLikeTypes types."""

def __init__(self, namespace: str) -> None:
self.namespace = namespace

def visit_type_var(self, t: TypeVarType) -> None:
t.id.namespace = self.namespace
super().visit_type_var(t)

def visit_param_spec(self, t: ParamSpecType) -> None:
t.id.namespace = self.namespace
return super().visit_param_spec(t)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
t.id.namespace = self.namespace
super().visit_type_var_tuple(t)
def __init__(
self, scope: TypeVarLikeScope, source_tv: TypeVarLikeExpr, context: Context
) -> None:
self.scope = scope
self.source_tv = source_tv
self.context = context
super().__init__()

def visit_type_var(self, t: TypeVarType) -> Type:
existing = self.scope.get_binding(t.fullname)
if existing is None:
self._report_unbound_tvar(t)
return AnyType(TypeOfAny.from_error)
return existing

def visit_param_spec(self, t: ParamSpecType) -> Type:
existing = self.scope.get_binding(t.fullname)
if existing is None:
self._report_unbound_tvar(t)
return AnyType(TypeOfAny.from_error)
return existing

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
existing = self.scope.get_binding(t.fullname)
if existing is None:
self._report_unbound_tvar(t)
return AnyType(TypeOfAny.from_error)
return existing

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
return t

def _report_unbound_tvar(self, tvar: TypeVarLikeType) -> None:
self.scope.msg.fail(
f"Type variable {tvar.name} referenced in the default"
f" of {self.source_tv.name} is unbound",
self.context,
)


class TypeVarLikeScope:
Expand All @@ -49,6 +79,8 @@ def __init__(
is_class_scope: bool = False,
prohibited: TypeVarLikeScope | None = None,
namespace: str = "",
*,
msg: MessageBuilder,
) -> None:
"""Initializer for TypeVarLikeScope

Expand All @@ -65,6 +97,7 @@ def __init__(
self.is_class_scope = is_class_scope
self.prohibited = prohibited
self.namespace = namespace
self.msg = msg
if parent is not None:
self.func_id = parent.func_id
self.class_id = parent.class_id
Expand All @@ -87,26 +120,34 @@ def allow_binding(self, fullname: str) -> bool:

def method_frame(self, namespace: str) -> TypeVarLikeScope:
"""A new scope frame for binding a method"""
return TypeVarLikeScope(self, False, None, namespace=namespace)
return TypeVarLikeScope(self, False, None, namespace=namespace, msg=self.msg)

def class_frame(self, namespace: str) -> TypeVarLikeScope:
"""A new scope frame for binding a class. Prohibits *this* class's tvars"""
return TypeVarLikeScope(self.get_function_scope(), True, self, namespace=namespace)
return TypeVarLikeScope(
self.get_function_scope(), True, self, namespace=namespace, msg=self.msg
)

def new_unique_func_id(self) -> TypeVarId:
"""Used by plugin-like code that needs to make synthetic generic functions."""
self.func_id -= 1
return TypeVarId(self.func_id)

def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType:
def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr, context: Context) -> TypeVarLikeType:
if self.is_class_scope:
self.class_id += 1
i = self.class_id
else:
self.func_id -= 1
i = self.func_id
namespace = self.namespace
tvar_expr.default.accept(TypeVarLikeNamespaceSetter(namespace))

# Defaults may reference other type variables. That is only valid when the
# referenced variable is already in scope (textually precedes the definition we're
# processing now).
default = tvar_expr.default.accept(
TypeVarLikeDefaultFixer(self, tvar_expr, context=context)
)

if isinstance(tvar_expr, TypeVarExpr):
tvar_def: TypeVarLikeType = TypeVarType(
Expand All @@ -115,7 +156,7 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType:
id=TypeVarId(i, namespace=namespace),
values=tvar_expr.values,
upper_bound=tvar_expr.upper_bound,
default=tvar_expr.default,
default=default,
variance=tvar_expr.variance,
line=tvar_expr.line,
column=tvar_expr.column,
Expand All @@ -127,7 +168,7 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType:
id=TypeVarId(i, namespace=namespace),
flavor=ParamSpecFlavor.BARE,
upper_bound=tvar_expr.upper_bound,
default=tvar_expr.default,
default=default,
line=tvar_expr.line,
column=tvar_expr.column,
)
Expand All @@ -138,7 +179,7 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType:
id=TypeVarId(i, namespace=namespace),
upper_bound=tvar_expr.upper_bound,
tuple_fallback=tvar_expr.tuple_fallback,
default=tvar_expr.default,
default=default,
line=tvar_expr.line,
column=tvar_expr.column,
)
Expand Down
6 changes: 3 additions & 3 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1561,7 +1561,7 @@ def analyze_callable_type(self, t: UnboundType) -> Type:
# below happens at very early stage.
variables = []
for name, tvar_expr in self.find_type_var_likes(callable_args):
variables.append(self.tvar_scope.bind_new(name, tvar_expr))
variables.append(self.tvar_scope.bind_new(name, tvar_expr, t))
maybe_ret = self.analyze_callable_args_for_paramspec(
callable_args, ret_type, fallback
) or self.analyze_callable_args_for_concatenate(
Expand Down Expand Up @@ -1833,7 +1833,7 @@ def bind_function_type_variables(
assert var_node, "Binding for function type variable not found within function"
var_expr = var_node.node
assert isinstance(var_expr, TypeVarLikeExpr)
binding = self.tvar_scope.bind_new(var.name, var_expr)
binding = self.tvar_scope.bind_new(var.name, var_expr, fun_type)
defs.append(binding)
return tuple(defs), has_self_type
typevars, has_self_type = self.infer_type_variables(fun_type)
Expand All @@ -1846,7 +1846,7 @@ def bind_function_type_variables(
if not self.tvar_scope.allow_binding(tvar.fullname):
err_msg = message_registry.TYPE_VAR_REDECLARED_IN_NESTED_CLASS.format(name)
self.fail(err_msg.value, defn, code=err_msg.code)
binding = self.tvar_scope.bind_new(name, tvar)
binding = self.tvar_scope.bind_new(name, tvar, fun_type)
defs.append(binding)

return tuple(defs), has_self_type
Expand Down
42 changes: 42 additions & 0 deletions test-data/unit/check-typevar-defaults.test
Original file line number Diff line number Diff line change
Expand Up @@ -880,3 +880,45 @@ reveal_type(A1().x) # N: Revealed type is "TypedDict('__main__.TD', {'foo': bui
reveal_type(A2().x) # N: Revealed type is "tuple[builtins.int, fallback=__main__.NT[builtins.int]]"
reveal_type(A3().x) # N: Revealed type is "TypedDict('__main__.TD', {'foo': builtins.int})"
[builtins fixtures/tuple.pyi]

[case testDefaultsApplicationInAliasNoCrash]
# https://github.com/python/mypy/issues/19186
from typing import Generic, TypeVar
from typing_extensions import TypeAlias

T1 = TypeVar("T1")
T2 = TypeVar("T2", default=T1)

Alias: TypeAlias = "MyClass[T1, T2]"

class MyClass(Generic["T1", "T2"]): ...
[builtins fixtures/tuple.pyi]

[case testDefaultsMustBeInScope]
from typing import Generic, TypeVar

T1 = TypeVar("T1")
T2 = TypeVar("T2", default=T1)
T3 = TypeVar("T3", default=T2)

class A(Generic[T1, T2, T3]): ...
reveal_type(A) # N: Revealed type is "def [T1, T2 = T1`1, T3 = T2`2 = T1`1] () -> __main__.A[T1`1, T2`2 = T1`1, T3`3 = T2`2 = T1`1]"
a: A[int]
reveal_type(a) # N: Revealed type is "__main__.A[builtins.int, builtins.int, T1`1]"

class B(Generic[T1, T3]): ... # E: Type variable T2 referenced in the default of T3 is unbound
reveal_type(B) # N: Revealed type is "def [T1, T3 = Any] () -> __main__.B[T1`1, T3`2 = Any]"
b: B[int]
reveal_type(b) # N: Revealed type is "__main__.B[builtins.int, Any]"

class C(Generic[T2]): ... # E: Type variable T1 referenced in the default of T2 is unbound
reveal_type(C) # N: Revealed type is "def [T2 = Any] () -> __main__.C[T2`1 = Any]"
c: C
reveal_type(c) # N: Revealed type is "__main__.C[Any]"

class D(Generic[T2, T1]): ... # E: Type variable T1 referenced in the default of T2 is unbound \
# E: "T1" cannot appear after "T2" in type parameter list because it has no default type
reveal_type(D) # N: Revealed type is "def [T2 = Any, T1 = Any] () -> __main__.D[T2`1 = Any, T1`2 = Any]"
d: D
reveal_type(d) # N: Revealed type is "__main__.D[Any, Any]"
[builtins fixtures/tuple.pyi]