Skip to content
Merged
69 changes: 59 additions & 10 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
from typing_extensions import get_origin, is_typeddict

import mypy.build
import mypy.checkexpr
import mypy.checkmember
import mypy.erasetype
import mypy.modulefinder
import mypy.nodes
import mypy.state
Expand Down Expand Up @@ -792,7 +795,11 @@ def _verify_arg_default_value(
"has a default value but stub parameter does not"
)
else:
runtime_type = get_mypy_type_of_runtime_value(runtime_arg.default)
type_context = stub_arg.variable.type
runtime_type = get_mypy_type_of_runtime_value(
runtime_arg.default, type_context=type_context
)

# Fallback to the type annotation type if var type is missing. The type annotation
# is an UnboundType, but I don't know enough to know what the pros and cons here are.
# UnboundTypes have ugly question marks following them, so default to var type.
Expand Down Expand Up @@ -1247,7 +1254,7 @@ def verify_var(
):
yield Error(object_path, "is read-only at runtime but not in the stub", stub, runtime)

runtime_type = get_mypy_type_of_runtime_value(runtime)
runtime_type = get_mypy_type_of_runtime_value(runtime, type_context=stub.type)
if (
runtime_type is not None
and stub.type is not None
Expand Down Expand Up @@ -1832,7 +1839,18 @@ def is_subtype_helper(left: mypy.types.Type, right: mypy.types.Type) -> bool:
return mypy.subtypes.is_subtype(left, right)


def get_mypy_type_of_runtime_value(runtime: Any) -> mypy.types.Type | None:
def get_mypy_node_for_name(module: str, type_name: str) -> mypy.nodes.SymbolNode | None:
stub = get_stub(module)
if stub is None:
return None
if type_name not in stub.names:
return None
return stub.names[type_name].node


def get_mypy_type_of_runtime_value(
runtime: Any, type_context: mypy.types.Type | None = None
) -> mypy.types.Type | None:
"""Returns a mypy type object representing the type of ``runtime``.

Returns None if we can't find something that works.
Expand Down Expand Up @@ -1893,14 +1911,45 @@ def anytype() -> mypy.types.AnyType:
is_ellipsis_args=True,
)

# Try and look up a stub for the runtime object
stub = get_stub(type(runtime).__module__)
if stub is None:
return None
type_name = type(runtime).__name__
if type_name not in stub.names:
skip_type_object_type = False
if type_context:
# Don't attempt to process the type object when context is generic
# This is related to issue #3737
type_context = mypy.types.get_proper_type(type_context)
# Callable types with a generic return value
if isinstance(type_context, mypy.types.CallableType):
if isinstance(type_context.ret_type, mypy.types.TypeVarType):
skip_type_object_type = True
# Type[x] where x is generic
if isinstance(type_context, mypy.types.TypeType):
if isinstance(type_context.item, mypy.types.TypeVarType):
skip_type_object_type = True

if isinstance(runtime, type) and not skip_type_object_type:

def _named_type(name: str) -> mypy.types.Instance:
parts = name.rsplit(".", maxsplit=1)
node = get_mypy_node_for_name(parts[0], parts[1])
assert isinstance(node, nodes.TypeInfo)
any_type = mypy.types.AnyType(mypy.types.TypeOfAny.special_form)
return mypy.types.Instance(node, [any_type] * len(node.defn.type_vars))

# Try and look up a stub for the runtime object itself
# The logic here is similar to ExpressionChecker.analyze_ref_expr
type_info = get_mypy_node_for_name(runtime.__module__, runtime.__name__)
if isinstance(type_info, nodes.TypeInfo):
result: mypy.types.Type | None = None
result = mypy.typeops.type_object_type(type_info, _named_type)
if mypy.checkexpr.is_type_type_context(type_context):
# This is the type in a type[] expression, so substitute type
# variables with Any.
result = mypy.erasetype.erase_typevars(result)
return result

# Try and look up a stub for the runtime object's type
type_info = get_mypy_node_for_name(type(runtime).__module__, type(runtime).__name__)
if type_info is None:
return None
type_info = stub.names[type_name].node
if isinstance(type_info, nodes.Var):
return type_info.type
if not isinstance(type_info, nodes.TypeInfo):
Expand Down
25 changes: 25 additions & 0 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2636,6 +2636,31 @@ class _X1: ...
error=None,
)

@collect_cases
def test_type_default_protocol(self) -> Iterator[Case]:
yield Case(
stub="""
from typing import Protocol

class _FormatterClass(Protocol):
def __call__(self, *, prog: str) -> HelpFormatter: ...

class ArgumentParser:
def __init__(self, formatter_class: _FormatterClass = ...) -> None: ...

class HelpFormatter:
def __init__(self, prog: str, indent_increment: int = 2) -> None: ...
""",
runtime="""
class HelpFormatter:
def __init__(self, prog, indent_increment=2) -> None: ...

class ArgumentParser:
def __init__(self, formatter_class=HelpFormatter): ...
""",
error=None,
)


def remove_color_code(s: str) -> str:
return re.sub("\\x1b.*?m", "", s) # this works!
Expand Down
Loading