Skip to content
Closed
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import collections.abc
from collections import defaultdict
from collections.abc import Callable as ABCallable
from functools import lru_cache, wraps, reduce
import gc
import inspect
Expand Down Expand Up @@ -42,6 +43,7 @@
from typing import TypeAlias
from typing import ParamSpec, Concatenate, ParamSpecArgs, ParamSpecKwargs
from typing import TypeGuard, TypeIs, NoDefault
from typing import _eval_type
import abc
import textwrap
import typing
Expand Down Expand Up @@ -10668,9 +10670,126 @@ def test_eq(self):
with self.assertWarns(DeprecationWarning):
self.assertNotEqual(int, typing._UnionGenericAlias)

class MyType:
pass

class TestGenericAliasHandling(BaseTestCase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of these tests pass on master today without your changes in typing.py.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not the one I sent in the comments!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't be adding tests that are unrelated to the change.


def test_forward_ref(self):
fwd_ref = ForwardRef('MyType')

def func(arg: fwd_ref):
pass

result = get_type_hints(func)
self.assertEqual(result['arg'], MyType, f"Expected MyType, got {result['arg']}")

def test_generic_alias(self):
fwd_ref = ForwardRef('MyType')
generic_list = List[fwd_ref]

def func(arg: generic_list):
pass

result = get_type_hints(func)
self.assertEqual(result['arg'], List[MyType], f"Expected List[MyType], got {result['arg']}")

def test_union(self):
fwd_ref_1 = ForwardRef('MyType')
fwd_ref_2 = ForwardRef('int')
union_type = Union[fwd_ref_1, fwd_ref_2]

def func(arg: union_type):
pass

result = get_type_hints(func)
self.assertEqual(result['arg'], Union[MyType, int], f"Expected Union[MyType, int], got {result['arg']}")

def test_recursive_forward_ref(self):
recursive_ref = ForwardRef('RecursiveType')
globals()['RecursiveType'] = recursive_ref
recursive_type = Dict[str, List[recursive_ref]]

def func(arg: recursive_type):
pass

result = get_type_hints(func)
self.assertEqual(result['arg'], Dict[str, List[recursive_ref]], f"Expected Dict[str, List[RecursiveType]], got {result['arg']}")

def test_callable_unpacking(self):
fwd_ref = ForwardRef('MyType')
callable_type = Callable[[fwd_ref, int], str]

def func(arg1: fwd_ref, arg2: int) -> str:
return "test"

result = get_type_hints(func)
self.assertEqual(result['arg1'], MyType, f"Expected MyType for arg1, got {result['arg1']}")
self.assertEqual(result['arg2'], int, f"Expected int for arg2, got {result['arg2']}")
self.assertEqual(result['return'], str, f"Expected str for return, got {result['return']}")

def test_unpacked_generic(self):
fwd_ref = ForwardRef('MyType')
generic_type = Tuple[fwd_ref, int]

def func(arg: generic_type):
pass

result = get_type_hints(func)
self.assertEqual(result['arg'], Tuple[MyType, int], f"Expected Tuple[MyType, int], got {result['arg']}")

def test_preservation_of_type(self):
fwd_ref_1 = ForwardRef('MyType')
fwd_ref_2 = ForwardRef('int')
complex_type = Dict[str, Union[fwd_ref_1, fwd_ref_2]]

def func(arg: complex_type):
pass

result = get_type_hints(func)
self.assertEqual(result['arg'], Dict[str, Union[MyType, int]], f"Expected Dict[str, Union[MyType, int]], got {result['arg']}")

def test_callable_unflattening(self):
callable_type = Callable[[int, str], bool]

def func(arg1: int, arg2: str) -> bool:
return True

result = get_type_hints(func)
self.assertEqual(result['arg1'], int, f"Expected int for arg1, got {result['arg1']}")
self.assertEqual(result['arg2'], str, f"Expected str for arg2, got {result['arg2']}")
self.assertEqual(result['return'], bool, f"Expected bool for return, got {result['return']}")

callable_type_packed = Callable[[int, str], bool]

def func_packed(arg1: int, arg2: str) -> bool:
return True

result = get_type_hints(func_packed)
self.assertEqual(result['arg1'], int, f"Expected int for arg1, got {result['arg1']}")
self.assertEqual(result['arg2'], str, f"Expected str for arg2, got {result['arg2']}")
self.assertEqual(result['return'], bool, f"Expected bool for return, got {result['return']}")

def test_hashable(self):
self.assertEqual(hash(typing._UnionGenericAlias), hash(Union))

class TestCallableAlias(BaseTestCase):
def test_callable_alias_preserves_subclass(self):
C = ABCallable[[str, ForwardRef('int')], int]
class A:
c: C
# Explicitly pass global namespace to ensure correct resolution
hints = get_type_hints(A, globalns=globals())

# Ensure evaluated type retains the correct subclass (_CallableGenericAlias)
self.assertEqual(hints['c'].__class__, C.__class__)

# Ensure evaluated type retains correct origin
self.assertEqual(hints['c'].__origin__, C.__origin__)

# Instead of comparing raw ForwardRef, check if the resolution is correct
expected_args = tuple(int if isinstance(arg, ForwardRef) else arg for arg in C.__args__)
self.assertEqual(hints['c'].__args__, expected_args)

def load_tests(loader, tests, pattern):
import doctest
Expand Down
4 changes: 3 additions & 1 deletion Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,9 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
if ev_args == t.__args__:
return t
if isinstance(t, GenericAlias):
return GenericAlias(t.__origin__, ev_args)
if _should_unflatten_callable_args(t, ev_args):
return t.__class__(t.__origin__, (ev_args[:-1], ev_args[-1]))
return t.__class__(t.__origin__, ev_args)
if isinstance(t, Union):
return functools.reduce(operator.or_, ev_args)
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure that typing.Callable retains its subclass (_CallableGenericAlias) instead of being incorrectly converted to GenericAlias.
Loading