Skip to content

Commit 58082b9

Browse files
committed
Increase test coverage
1 parent 881ffee commit 58082b9

File tree

2 files changed

+66
-32
lines changed

2 files changed

+66
-32
lines changed

src/test_typing_extensions.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,40 +1650,50 @@ def test_annotation_and_optional_default(self):
16501650
optional_annotation = Optional[annotation]
16511651

16521652
cases = {
1653-
# (annotation, none_default) : expected_type_hints
1654-
((), False): {},
1653+
# (annotation, skip_as_str): expected_type_hints
1654+
# Should skip_as_str if contains a ForwardRef.
16551655
((), True): {},
1656-
(int, False): {'x': int},
1657-
(int, True): {'x': int},
1658-
(Optional[int], False): {'x': Optional[int]},
1659-
(Optional[int], True): {'x': Optional[int]},
1660-
(optional_annotation, False): {'x': optional_annotation},
1661-
(optional_annotation, True): {'x': optional_annotation},
1662-
(str(optional_annotation), True): {'x': optional_annotation},
1663-
(annotation, False): {'x': annotation},
1664-
(annotation, True): {'x': annotation},
1665-
(Union[annotation, T], False): {'x': Union[annotation, T]},
1666-
(Union[annotation, T], True): {'x': Union[annotation, T]},
1656+
(int, True): {"x": int},
1657+
("int", True): {"x": int},
1658+
(Optional[int], False): {"x": Optional[int]},
1659+
(optional_annotation, False): {"x": optional_annotation},
1660+
(annotation, False): {"x": annotation},
1661+
(Union[annotation, T], True): {"x": Union[annotation, T]},
16671662
("Union[Annotated[Union[int, None], 'data'], T]", True): {
1668-
'x': Union[annotation, T]
1663+
"x": Union[annotation, T]
16691664
},
1670-
(Union[str, None, "str"], False): {'x': Optional[str]},
1671-
(Union[str, None, "str"], True): {'x': Optional[str]},
1672-
(Union[str, "str"], False): {
1673-
'x': str
1674-
if sys.version_info >= (3, 9)
1675-
# _eval_type does not resolve correctly to str in 3.8
1676-
else typing._eval_type(Union[str, "str"], None, None),
1665+
(Union[str, None, str], False): {"x": Optional[str]},
1666+
(Union[str, None, "str"], True): {"x": Optional[str]},
1667+
(Union[str, "str"], True): {"x": str},
1668+
(List["str"], True): {"x": List[str]},
1669+
(Optional[List[str]], False): {"x": Optional[List[str]]},
1670+
(Tuple[Unpack[Tuple[int, str]]], False): {
1671+
"x": Tuple[Unpack[Tuple[int, str]]]
16771672
},
1678-
(Union[str, "str"], True): {'x': str},
1679-
(List["str"], False): {'x': List[str]},
1680-
(List["str"], True): {'x': List[str]},
1681-
(Optional[List[str]], False): {'x': Optional[List[str]]},
1682-
(Optional[List[str]], True): {'x': Optional[List[str]]},
16831673
}
1684-
1685-
for (annot, none_default), expected in cases.items():
1686-
with self.subTest(annotation=annot, none_default=none_default, expected_type_hints=expected):
1674+
for ((annot, skip_as_str), expected), none_default, as_str, wrap_optional in itertools.product(
1675+
cases.items(), (False, True), (False, True), (False, True)
1676+
):
1677+
if wrap_optional:
1678+
if annot == ():
1679+
continue
1680+
if (get_origin(annot) is not Optional
1681+
or (sys.version_info[:2] == (3, 8) and annot._name != "Optional")
1682+
):
1683+
annot = Optional[annot]
1684+
expected = {"x": Optional[expected['x']]}
1685+
if as_str:
1686+
if skip_as_str or annot == ():
1687+
continue
1688+
annot = str(annot)
1689+
with self.subTest(
1690+
annotation=annot,
1691+
as_str=as_str,
1692+
none_default=none_default,
1693+
expected_type_hints=expected,
1694+
wrap_optional=wrap_optional,
1695+
):
1696+
# Create function to check
16871697
if annot == ():
16881698
if none_default:
16891699
def func(x=None): pass
@@ -1693,7 +1703,14 @@ def func(x): pass
16931703
def func(x: annot = None): pass
16941704
else:
16951705
def func(x: annot): pass
1696-
self.assertEqual(get_type_hints(func, include_extras=True), expected)
1706+
type_hints = get_type_hints(func, include_extras=True)
1707+
self.assertEqual(type_hints, expected)
1708+
self.assertEqual(hash(type_hints.values()), hash(expected.values()))
1709+
with self.subTest("Test str and repr"):
1710+
if sys.version_info[:2] == (3, 8) and annot == Union[str, None, "str"]:
1711+
# This also skips Union[str, "str"] wrap_optional=True which has the same problem
1712+
self.skipTest("In 3.8 repr is Union[str, None, str]")
1713+
self.assertEqual(str(type_hints)+repr(type_hints), str(expected)+repr(type_hints))
16971714

16981715

16991716
class GetUtilitiesTestCase(TestCase):

src/typing_extensions.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,22 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
12381238
hint = typing.get_type_hints(obj, globalns=globalns, localns=localns)
12391239
if sys.version_info < (3, 11):
12401240
_clean_optional(obj, hint, globalns, localns)
1241+
# types from get_type_hints might not be from a cached version
1242+
# In 3.8 eval_type does not handle all Optional[ForwardRef] correctly
1243+
# this also returns cached versions of Union and Optional
1244+
if sys.version_info < (3, 9):
1245+
hint = {
1246+
k: (
1247+
t
1248+
if get_origin(t) not in (Union, Optional)
1249+
else (
1250+
Optional[t.__args__[0]]
1251+
if get_origin(t) == Optional
1252+
else Union[t.__args__]
1253+
)
1254+
)
1255+
for k, t in hint.items()
1256+
}
12411257
if include_extras:
12421258
return hint
12431259
return {k: _strip_extras(t) for k, t in hint.items()}
@@ -1261,7 +1277,7 @@ def _clean_optional(obj, hints, globalns=None, localns=None):
12611277
# see https://github.com/python/typing_extensions/issues/310
12621278
if not hints or isinstance(obj, type):
12631279
return
1264-
defaults = typing._get_defaults(obj)
1280+
defaults = typing._get_defaults(obj) # avoid accessing __annotations___
12651281
if not defaults:
12661282
return
12671283
original_hints = obj.__annotations__
@@ -1300,7 +1316,8 @@ def _clean_optional(obj, hints, globalns=None, localns=None):
13001316
original_evaluated = typing._eval_type(original_value, globalns, localns)
13011317
if sys.version_info < (3, 9) and get_origin(original_evaluated) is Union:
13021318
# Union[str, None, "str"] is not reduced to Union[str, None]
1303-
original_evaluated = Union[original_evaluated.__args__]
1319+
container = Optional if original_evaluated._name == "Optional" else Union
1320+
original_evaluated = container[original_evaluated.__args__]
13041321
# Compare if values differ
13051322
if original_evaluated != value:
13061323
hints[name] = original_evaluated

0 commit comments

Comments
 (0)