@@ -935,6 +935,31 @@ def _is_valid_dispatch_type(cls):
935935 return (isinstance (cls , UnionType ) and
936936 all (isinstance (arg , type ) for arg in cls .__args__ ))
937937
938+ def _get_type_hints (func ):
939+ ann = getattr (func , '__annotate__' , None )
940+ if ann is None :
941+ raise TypeError (
942+ f"Invalid first argument to `register()`: { func !r} . "
943+ f"Use either `@register(some_class)` or plain `@register` "
944+ f"on an annotated function."
945+ )
946+
947+ # only import typing if annotation parsing is necessary
948+ from typing import get_type_hints
949+ from annotationlib import Format
950+
951+ type_hints = get_type_hints (func , format = Format .FORWARDREF )
952+ type_hints .pop ("return" , None ) # don't dispatch on return types
953+
954+ if not type_hints :
955+ raise TypeError (
956+ f"Invalid first argument to `register()`: { func !r} . "
957+ f"Use either `@register(some_class)` or plain `@register` "
958+ f"on a function with annotated parameters."
959+ )
960+
961+ return type_hints
962+
938963 def register (cls , func = None ):
939964 """generic_func.register(cls, func) -> func
940965
@@ -951,20 +976,14 @@ def register(cls, func=None):
951976 f"Invalid first argument to `register()`. "
952977 f"{ cls !r} is not a class or union type."
953978 )
954- ann = getattr (cls , '__annotate__' , None )
955- if ann is None :
956- raise TypeError (
957- f"Invalid first argument to `register()`: { cls !r} . "
958- f"Use either `@register(some_class)` or plain `@register` "
959- f"on an annotated function."
960- )
961979 func = cls
980+ type_hints = _get_type_hints (func )
981+
982+ argname , cls = next (iter (type_hints .items ()))
962983
963- # only import typing if annotation parsing is necessary
964- from typing import get_type_hints
965- from annotationlib import Format , ForwardRef
966- argname , cls = next (iter (get_type_hints (func , format = Format .FORWARDREF ).items ()))
967984 if not _is_valid_dispatch_type (cls ):
985+ from annotationlib import ForwardRef
986+
968987 if isinstance (cls , UnionType ):
969988 raise TypeError (
970989 f"Invalid annotation for { argname !r} . "
0 commit comments