Skip to content

Commit 8cc4464

Browse files
committed
Restructured tests and applied suggested changes
1 parent 7f000b1 commit 8cc4464

File tree

2 files changed

+58
-44
lines changed

2 files changed

+58
-44
lines changed

src/test_typing_extensions.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,32 +1649,47 @@ def test_annotation_and_optional_default(self):
16491649
annotation = Annotated[Union[int, None], "data"]
16501650
optional_annotation = Optional[annotation]
16511651

1652-
def wanted_optional(bar: optional_annotation): ...
1653-
def wanted_optional_default(bar: optional_annotation = None): ...
1654-
def wanted_optional_ref(bar: 'Optional[Annotated[Union[int, None], "data"]]'): ...
1655-
1656-
def no_optional(bar: annotation): ...
1657-
def no_optional_default(bar: annotation = None): ...
1658-
def no_optional_defaultT(bar: Union[annotation, T] = None): ...
1659-
def no_optional_defaultT_ref(bar: "Union[annotation, T]" = None): ...
1660-
1661-
for func in(wanted_optional, wanted_optional_default, wanted_optional_ref):
1662-
self.assertEqual(
1663-
get_type_hints(func, include_extras=True),
1664-
{"bar": optional_annotation}
1665-
)
1666-
1667-
for func in (no_optional, no_optional_default):
1668-
self.assertEqual(
1669-
get_type_hints(func, include_extras=True),
1670-
{"bar": annotation}
1671-
)
1652+
cases = {
1653+
((), False): {},
1654+
((), True): {},
1655+
(int, False): {"x": int},
1656+
(int, True): {"x": int},
1657+
(Optional[int], False): {"x": Optional[int]},
1658+
(Optional[int], True): {"x": Optional[int]},
1659+
(optional_annotation, False): {"x": optional_annotation},
1660+
(optional_annotation, True): {"x": optional_annotation},
1661+
(str(optional_annotation), True): {"x": optional_annotation},
1662+
(annotation, False): {"x": annotation},
1663+
(annotation, True): {"x": annotation},
1664+
(Union[annotation, T], False): {"x": Union[annotation, T]},
1665+
(Union[annotation, T], True): {"x": Union[annotation, T]},
1666+
(Union[str, None, "str"], False): {"x": Optional[str]},
1667+
(Union[str, None, "str"], True): {"x": Optional[str]},
1668+
(Union[str, "str"], False): {
1669+
"x": str
1670+
if sys.version_info >= (3, 9)
1671+
# _eval_type does not resolve correctly to str in 3.8
1672+
else typing._eval_type(Union[str, "str"], None, None),
1673+
},
1674+
(Union[str, "str"], True): {"x": str},
1675+
(List["str"], False): {"x": List[str]},
1676+
(List["str"], True): {"x": List[str]},
1677+
(Optional[List[str]], False): {"x": Optional[List[str]]},
1678+
(Optional[List[str]], True): {"x": Optional[List[str]]},
1679+
}
16721680

1673-
for func in (no_optional_defaultT, no_optional_defaultT_ref):
1674-
self.assertEqual(
1675-
get_type_hints(func, globals(), locals(), include_extras=True),
1676-
{"bar": Union[annotation, T]}
1677-
)
1681+
for (annot, none_default), expected in cases.items():
1682+
with self.subTest(annotation=annot, none_default=none_default, expected_type_hints=expected):
1683+
if annot == ():
1684+
if none_default:
1685+
def func(x = None): pass
1686+
else:
1687+
def func(x): pass
1688+
elif none_default:
1689+
def func(x: annot = None): pass
1690+
else:
1691+
def func(x: annot): pass
1692+
self.assertEqual(get_type_hints(func, include_extras=True), expected)
16781693

16791694

16801695
class GetUtilitiesTestCase(TestCase):

src/typing_extensions.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,8 +1236,8 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
12361236
)
12371237
else: # 3.8
12381238
hint = typing.get_type_hints(obj, globalns=globalns, localns=localns)
1239-
if sys.version_info < (3, 11) and hint:
1240-
hint = _clean_optional(obj, hint, globalns, localns)
1239+
if sys.version_info < (3, 11):
1240+
_clean_optional(obj, hint, globalns, localns)
12411241
if include_extras:
12421242
return hint
12431243
return {k: _strip_extras(t) for k, t in hint.items()}
@@ -1247,7 +1247,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
12471247
def _could_be_inserted_optional(t):
12481248
"""detects Union[..., None] pattern"""
12491249
# 3.8+ compatible checking before _UnionGenericAlias
1250-
if not hasattr(t, "__origin__") or t.__origin__ is not Union:
1250+
if get_origin(t) is not Union:
12511251
return False
12521252
# Assume if last argument is not None they are user defined
12531253
if t.__args__[-1] is not _NoneType:
@@ -1259,8 +1259,12 @@ def _clean_optional(obj, hints, globalns=None, localns=None):
12591259
# reverts injected Union[..., None] cases from typing.get_type_hints
12601260
# when a None default value is used.
12611261
# see https://github.com/python/typing_extensions/issues/310
1262-
original_hints = getattr(obj, '__annotations__', None)
1262+
if not hints or isinstance(obj, type):
1263+
return
12631264
defaults = typing._get_defaults(obj)
1265+
if not defaults:
1266+
return
1267+
original_hints = obj.__annotations__
12641268
for name, value in hints.items():
12651269
# Not a Union[..., None] or replacement conditions not fullfilled
12661270
if (not _could_be_inserted_optional(value)
@@ -1269,7 +1273,7 @@ def _clean_optional(obj, hints, globalns=None, localns=None):
12691273
):
12701274
continue
12711275
original_value = original_hints[name]
1272-
if original_value is None:
1276+
if original_value is None: # should not happen
12731277
original_value = _NoneType
12741278
# Forward reference
12751279
if isinstance(original_value, str):
@@ -1287,24 +1291,19 @@ def _clean_optional(obj, hints, globalns=None, localns=None):
12871291
elif localns is None:
12881292
localns = globalns
12891293
if sys.version_info < (3, 9):
1290-
ref = ForwardRef(original_value)
1294+
original_value = ForwardRef(original_value)
12911295
else:
1292-
ref = ForwardRef(
1296+
original_value = ForwardRef(
12931297
original_value,
12941298
is_argument=not isinstance(obj, _types.ModuleType)
12951299
)
1296-
original_value = typing._eval_type(ref, globalns, localns)
1297-
# Values was not modified or original is already Optional
1298-
if original_value == value or _could_be_inserted_optional(original_value):
1299-
continue
1300-
# NoneType was added to value
1301-
if len(value.__args__) == 2:
1302-
hints[name] = value.__args__[0] # not a Union
1303-
else:
1304-
hints[name] = Union[value.__args__[:-1]] # still a Union
1305-
1306-
return hints
1307-
1300+
original_evaluated = typing._eval_type(original_value, globalns, localns)
1301+
if sys.version_info < (3, 9) and get_origin(original_evaluated) is Union:
1302+
# Union[str, None, "str"] is not reduced to Union[str, None]
1303+
original_evaluated = Union[original_evaluated.__args__]
1304+
# Compare if values differ
1305+
if original_evaluated != value:
1306+
hints[name] = original_evaluated
13081307

13091308
# Python 3.9+ has PEP 593 (Annotated)
13101309
if hasattr(typing, 'Annotated'):

0 commit comments

Comments
 (0)