Skip to content

Commit dfdf7ac

Browse files
committed
Don't single-dispatch on return types
1 parent fbf0843 commit dfdf7ac

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

Lib/functools.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,31 @@ def _is_valid_dispatch_type(cls):
935935
return (isinstance(cls, UnionType) and
936936
all(isinstance(arg, type) for arg in cls.__args__))
937937

938+
def _get_type_hints(func):
939+
ann = getattr(func, '__annotate__', None)
940+
if ann is None:
941+
raise TypeError(
942+
f"Invalid first argument to `register()`: {func!r}. "
943+
f"Use either `@register(some_class)` or plain `@register` "
944+
f"on an annotated function."
945+
)
946+
947+
# only import typing if annotation parsing is necessary
948+
from typing import get_type_hints
949+
from annotationlib import Format
950+
951+
type_hints = get_type_hints(func, format=Format.FORWARDREF)
952+
type_hints.pop("return", None) # don't dispatch on return types
953+
954+
if not type_hints:
955+
raise TypeError(
956+
f"Invalid first argument to `register()`: {func!r}. "
957+
f"Use either `@register(some_class)` or plain `@register` "
958+
f"on a function with annotated parameters."
959+
)
960+
961+
return type_hints
962+
938963
def register(cls, func=None):
939964
"""generic_func.register(cls, func) -> func
940965
@@ -951,20 +976,14 @@ def register(cls, func=None):
951976
f"Invalid first argument to `register()`. "
952977
f"{cls!r} is not a class or union type."
953978
)
954-
ann = getattr(cls, '__annotate__', None)
955-
if ann is None:
956-
raise TypeError(
957-
f"Invalid first argument to `register()`: {cls!r}. "
958-
f"Use either `@register(some_class)` or plain `@register` "
959-
f"on an annotated function."
960-
)
961979
func = cls
980+
type_hints = _get_type_hints(func)
981+
982+
argname, cls = next(iter(type_hints.items()))
962983

963-
# only import typing if annotation parsing is necessary
964-
from typing import get_type_hints
965-
from annotationlib import Format, ForwardRef
966-
argname, cls = next(iter(get_type_hints(func, format=Format.FORWARDREF).items()))
967984
if not _is_valid_dispatch_type(cls):
985+
from annotationlib import ForwardRef
986+
968987
if isinstance(cls, UnionType):
969988
raise TypeError(
970989
f"Invalid annotation for {argname!r}. "

Lib/test/test_functools.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3180,6 +3180,18 @@ def _(arg):
31803180
)
31813181
self.assertEndsWith(str(exc.exception), msg_suffix)
31823182

3183+
with self.assertRaises(TypeError) as exc:
3184+
@i.register
3185+
def _(arg) -> str:
3186+
return "I only have a return type annotation"
3187+
self.assertStartsWith(str(exc.exception), msg_prefix +
3188+
"<function TestSingleDispatch.test_invalid_registrations.<locals>._"
3189+
)
3190+
self.assertEndsWith(str(exc.exception),
3191+
". Use either `@register(some_class)` or plain `@register` on "
3192+
"a function with annotated parameters."
3193+
)
3194+
31833195
with self.assertRaises(TypeError) as exc:
31843196
@i.register
31853197
def _(arg: typing.Iterable[str]):

0 commit comments

Comments
 (0)