Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,15 @@ def __init__(
self.__forward_is_argument__ = is_argument
self.__forward_is_class__ = is_class
self.__forward_module__ = module
self.__owner__ = owner
# These are always set to None here but may be non-None if a ForwardRef
# is created through __class__ assignment on a _Stringifier object.
self.__globals__ = None
self.__cell__ = None
# These are initially None but serve as a cache and may be set to a non-None
# value later.
self.__code__ = None
self.__ast_node__ = None
self.__cell__ = None
self.__owner__ = owner

def __init_subclass__(cls, /, *args, **kwds):
raise TypeError("Cannot subclass ForwardRef")
Expand Down
161 changes: 160 additions & 1 deletion Lib/test/test_annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import functools
import itertools
import pickle
import typing
import unittest
from annotationlib import (
Format,
Expand All @@ -15,7 +16,12 @@
annotations_to_string,
type_repr,
)
from typing import Unpack
from typing import (
Unpack,
get_type_hints,
List,
Union,
)

from test import support
from test.test_inspect import inspect_stock_annotations
Expand Down Expand Up @@ -1205,6 +1211,159 @@ def test_annotations_to_string(self):
)


class A:
pass


class ForwardRefTests(unittest.TestCase):
def test_forwardref_instance_type_error(self):
fr = ForwardRef('int')
with self.assertRaises(TypeError):
isinstance(42, fr)

def test_forwardref_subclass_type_error(self):
fr = ForwardRef('int')
with self.assertRaises(TypeError):
issubclass(int, fr)

def test_forwardref_only_str_arg(self):
with self.assertRaises(TypeError):
ForwardRef(1) # only `str` type is allowed

def test_forward_equality(self):
fr = ForwardRef('int')
self.assertEqual(fr, ForwardRef('int'))
self.assertNotEqual(List['int'], List[int])
self.assertNotEqual(fr, ForwardRef('int', module=__name__))
frm = ForwardRef('int', module=__name__)
self.assertEqual(frm, ForwardRef('int', module=__name__))
self.assertNotEqual(frm, ForwardRef('int', module='__other_name__'))

def test_forward_equality_get_type_hints(self):
c1 = ForwardRef('C')
c1_gth = ForwardRef('C')
c2 = ForwardRef('C')
c2_gth = ForwardRef('C')

class C:
pass
def foo(a: c1_gth, b: c2_gth):
pass

self.assertEqual(get_type_hints(foo, globals(), locals()), {'a': C, 'b': C})
self.assertEqual(c1, c2)
self.assertEqual(c1, c1_gth)
self.assertEqual(c1_gth, c2_gth)
self.assertEqual(List[c1], List[c1_gth])
self.assertNotEqual(List[c1], List[C])
self.assertNotEqual(List[c1_gth], List[C])
self.assertEqual(Union[c1, c1_gth], Union[c1])
self.assertEqual(Union[c1, c1_gth, int], Union[c1, int])

def test_forward_equality_hash(self):
c1 = ForwardRef('int')
c1_gth = ForwardRef('int')
c2 = ForwardRef('int')
c2_gth = ForwardRef('int')

def foo(a: c1_gth, b: c2_gth):
pass
get_type_hints(foo, globals(), locals())

self.assertEqual(hash(c1), hash(c2))
self.assertEqual(hash(c1_gth), hash(c2_gth))
self.assertEqual(hash(c1), hash(c1_gth))

c3 = ForwardRef('int', module=__name__)
c4 = ForwardRef('int', module='__other_name__')

self.assertNotEqual(hash(c3), hash(c1))
self.assertNotEqual(hash(c3), hash(c1_gth))
self.assertNotEqual(hash(c3), hash(c4))
self.assertEqual(hash(c3), hash(ForwardRef('int', module=__name__)))

def test_forward_equality_namespace(self):
def namespace1():
a = ForwardRef('A')
def fun(x: a):
pass
get_type_hints(fun, globals(), locals())
return a

def namespace2():
a = ForwardRef('A')

class A:
pass
def fun(x: a):
pass

get_type_hints(fun, globals(), locals())
return a

self.assertEqual(namespace1(), namespace1())
self.assertEqual(namespace1(), namespace2())

def test_forward_repr(self):
self.assertEqual(repr(List['int']), "typing.List[ForwardRef('int')]")
self.assertEqual(repr(List[ForwardRef('int', module='mod')]),
"typing.List[ForwardRef('int', module='mod')]")

def test_forward_recursion_actually(self):
def namespace1():
a = ForwardRef('A')
A = a
def fun(x: a): pass

ret = get_type_hints(fun, globals(), locals())
return a

def namespace2():
a = ForwardRef('A')
A = a
def fun(x: a): pass

ret = get_type_hints(fun, globals(), locals())
return a

r1 = namespace1()
r2 = namespace2()
self.assertIsNot(r1, r2)
self.assertEqual(r1, r2)

def test_syntax_error(self):

with self.assertRaises(SyntaxError):
typing.Generic['/T']

def test_delayed_syntax_error(self):

def foo(a: 'Node[T'):
pass

with self.assertRaises(SyntaxError):
get_type_hints(foo)

def test_syntax_error_empty_string(self):
for form in [typing.List, typing.Set, typing.Type, typing.Deque]:
with self.subTest(form=form):
with self.assertRaises(SyntaxError):
form['']

def test_or(self):
X = ForwardRef('X')
# __or__/__ror__ itself
self.assertEqual(X | "x", Union[X, "x"])
self.assertEqual("x" | X, Union["x", X])

def test_multiple_ways_to_create(self):
X1 = Union["X"]
self.assertIsInstance(X1, ForwardRef)
X2 = ForwardRef("X")
self.assertIsInstance(X2, ForwardRef)
self.assertEqual(X1, X2)


class TestAnnotationLib(unittest.TestCase):
def test__all__(self):
support.check__all__(self, annotationlib)
Loading
Loading