Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def __init__(
self.__forward_is_argument__ = is_argument
self.__forward_is_class__ = is_class
self.__forward_module__ = module
# These are always set to None here but may be non-None if a ForwardRef
# is created through __class__ assignment on a _Stringifier object.
self.__globals__ = None
self.__code__ = None
self.__ast_node__ = None
Expand Down
161 changes: 160 additions & 1 deletion Lib/test/test_annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import functools
import itertools
import pickle
import typing
import unittest
from annotationlib import (
Format,
Expand All @@ -15,7 +16,12 @@
annotations_to_string,
type_repr,
)
from typing import Unpack
from typing import (
Unpack,
get_type_hints,
List,
Union,
)

from test import support
from test.test_inspect import inspect_stock_annotations
Expand Down Expand Up @@ -1205,6 +1211,159 @@ def test_annotations_to_string(self):
)


class A:
pass


class ForwardRefTests(unittest.TestCase):
def test_forwardref_instance_type_error(self):
fr = ForwardRef('int')
with self.assertRaises(TypeError):
isinstance(42, fr)

def test_forwardref_subclass_type_error(self):
fr = ForwardRef('int')
with self.assertRaises(TypeError):
issubclass(int, fr)

def test_forwardref_only_str_arg(self):
with self.assertRaises(TypeError):
ForwardRef(1) # only `str` type is allowed

def test_forward_equality(self):
fr = ForwardRef('int')
self.assertEqual(fr, ForwardRef('int'))
self.assertNotEqual(List['int'], List[int])
self.assertNotEqual(fr, ForwardRef('int', module=__name__))
frm = ForwardRef('int', module=__name__)
self.assertEqual(frm, ForwardRef('int', module=__name__))
self.assertNotEqual(frm, ForwardRef('int', module='__other_name__'))

def test_forward_equality_gth(self):
c1 = ForwardRef('C')
c1_gth = ForwardRef('C')
c2 = ForwardRef('C')
c2_gth = ForwardRef('C')

class C:
pass
def foo(a: c1_gth, b: c2_gth):
pass

self.assertEqual(get_type_hints(foo, globals(), locals()), {'a': C, 'b': C})
self.assertEqual(c1, c2)
self.assertEqual(c1, c1_gth)
self.assertEqual(c1_gth, c2_gth)
self.assertEqual(List[c1], List[c1_gth])
self.assertNotEqual(List[c1], List[C])
self.assertNotEqual(List[c1_gth], List[C])
self.assertEqual(Union[c1, c1_gth], Union[c1])
self.assertEqual(Union[c1, c1_gth, int], Union[c1, int])

def test_forward_equality_hash(self):
c1 = ForwardRef('int')
c1_gth = ForwardRef('int')
c2 = ForwardRef('int')
c2_gth = ForwardRef('int')

def foo(a: c1_gth, b: c2_gth):
pass
get_type_hints(foo, globals(), locals())

self.assertEqual(hash(c1), hash(c2))
self.assertEqual(hash(c1_gth), hash(c2_gth))
self.assertEqual(hash(c1), hash(c1_gth))

c3 = ForwardRef('int', module=__name__)
c4 = ForwardRef('int', module='__other_name__')

self.assertNotEqual(hash(c3), hash(c1))
self.assertNotEqual(hash(c3), hash(c1_gth))
self.assertNotEqual(hash(c3), hash(c4))
self.assertEqual(hash(c3), hash(ForwardRef('int', module=__name__)))

def test_forward_equality_namespace(self):
def namespace1():
a = ForwardRef('A')
def fun(x: a):
pass
get_type_hints(fun, globals(), locals())
return a

def namespace2():
a = ForwardRef('A')

class A:
pass
def fun(x: a):
pass

get_type_hints(fun, globals(), locals())
return a

self.assertEqual(namespace1(), namespace1())
self.assertEqual(namespace1(), namespace2())

def test_forward_repr(self):
self.assertEqual(repr(List['int']), "typing.List[ForwardRef('int')]")
self.assertEqual(repr(List[ForwardRef('int', module='mod')]),
"typing.List[ForwardRef('int', module='mod')]")

def test_forward_recursion_actually(self):
def namespace1():
a = ForwardRef('A')
A = a
def fun(x: a): pass

ret = get_type_hints(fun, globals(), locals())
return a

def namespace2():
a = ForwardRef('A')
A = a
def fun(x: a): pass

ret = get_type_hints(fun, globals(), locals())
return a

r1 = namespace1()
r2 = namespace2()
self.assertIsNot(r1, r2)
self.assertEqual(r1, r2)

def test_syntax_error(self):

with self.assertRaises(SyntaxError):
typing.Generic['/T']

def test_delayed_syntax_error(self):

def foo(a: 'Node[T'):
pass

with self.assertRaises(SyntaxError):
get_type_hints(foo)

def test_syntax_error_empty_string(self):
for form in [typing.List, typing.Set, typing.Type, typing.Deque]:
with self.subTest(form=form):
with self.assertRaises(SyntaxError):
form['']

def test_or(self):
X = ForwardRef('X')
# __or__/__ror__ itself
self.assertEqual(X | "x", Union[X, "x"])
self.assertEqual("x" | X, Union["x", X])

def test_multiple_ways_to_create(self):
X1 = Union["X"]
self.assertIsInstance(X1, ForwardRef)
X2 = ForwardRef("X")
self.assertIsInstance(X2, ForwardRef)
self.assertEqual(X1, X2)


class TestAnnotationLib(unittest.TestCase):
def test__all__(self):
support.check__all__(self, annotationlib)
Loading
Loading