diff --git a/Lib/functools.py b/Lib/functools.py index 714070c6ac9460..3cd4c23c7104cf 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -843,8 +843,12 @@ def is_strict_base(typ): mro.append(subcls) return _c3_mro(cls, abcs=mro) -def _find_impl(cls, registry): - """Returns the best matching implementation from *registry* for type *cls*. +def _pep585_registry_matches(cls, registry): + from typing import get_origin + return (i for i in registry.keys() if get_origin(i) == cls) + +def _find_impl_match(cls_obj, registry): + """Returns the best matching implementation from *registry* for type *cls_obj*. Where there is no registered implementation for a specific type, its method resolution order is used to find a more generic implementation. @@ -853,8 +857,35 @@ def _find_impl(cls, registry): *object* type, this function may return None. """ + cls = cls_obj if isinstance(cls_obj, type) else cls_obj.__class__ mro = _compose_mro(cls, registry.keys()) match = None + + from typing import get_origin, get_args + + if (not isinstance(cls_obj, type) and + len(cls_obj) > 0 and # dont try to match the types of empty containers + any(_pep585_registry_matches(cls, registry))): + # check containers that match cls first + for t in _pep585_registry_matches(cls, registry): + if not all((isinstance(i, get_args(t)) for i in cls_obj)): + continue + + if match is None: + match = t + + else: + match_args = get_args(get_args(match)[0]) + t_args = get_args(get_args(t)[0]) + if len(match_args) == len(t_args): + raise RuntimeError("Ambiguous dispatch: {} or {}".format( match, t)) + + elif len(t_args) Runs the dispatch algorithm to return the best available implementation for the given *cls* registered on *generic_func*. """ + cls = cls_obj.__class__ nonlocal cache_token if cache_token is not None: current_token = get_cache_token() if cache_token != current_token: dispatch_cache.clear() cache_token = current_token - try: - impl = dispatch_cache[cls] - except KeyError: - try: - impl = registry[cls] - except KeyError: - impl = _find_impl(cls, registry) - dispatch_cache[cls] = impl - return impl + + # if PEP-585 types are not registered for the given *cls*, + # then we can use the cache. Otherwise, the cache cannot be used + # because we need to confirm every item matches first + if not any(_pep585_registry_matches(cls, registry)): + return _fetch_dispatch_with_cache(cls) + + return _find_impl(cls_obj, registry) def _is_valid_dispatch_type(cls): if isinstance(cls, type): return True + + if isinstance(cls, GenericAlias): + from typing import get_args + return all(isinstance(arg, (type, UnionType)) for arg in get_args(cls)) + return (isinstance(cls, UnionType) and - all(isinstance(arg, type) for arg in cls.__args__)) + all(isinstance(arg, (type, GenericAlias)) for arg in cls.__args__)) + def register(cls, func=None): """generic_func.register(cls, func) -> func @@ -932,6 +988,7 @@ def register(cls, func=None): f"Invalid first argument to `register()`. " f"{cls!r} is not a class or union type." ) + ann = getattr(cls, '__annotate__', None) if ann is None: raise TypeError( @@ -976,7 +1033,7 @@ def wrapper(*args, **kw): if not args: raise TypeError(f'{funcname} requires at least ' '1 positional argument') - return dispatch(args[0].__class__)(*args, **kw) + return dispatch(args[0])(*args, **kw) funcname = getattr(func, '__name__', 'singledispatch function') registry[object] = func @@ -1064,7 +1121,7 @@ def __call__(self, /, *args, **kwargs): 'singledispatchmethod method') raise TypeError(f'{funcname} requires at least ' '1 positional argument') - return self._dispatch(args[0].__class__).__get__(self._obj, self._cls)(*args, **kwargs) + return self._dispatch(args[0]).__get__(self._obj, self._cls)(*args, **kwargs) def __getattr__(self, name): # Resolve these attributes lazily to speed up creation of diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 2b49615178f136..fbc6623e83be01 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1,3 +1,4 @@ +from __future__ import annotations import abc import builtins import collections @@ -2136,6 +2137,250 @@ def cached_staticmeth(x, y): class TestSingleDispatch(unittest.TestCase): + + def test_pep585_basic(self): + @functools.singledispatch + def g(obj): + return "base" + def g_list_int(li): + return "list of ints" + # previously this failed with: 'not a class' + g.register(list[int], g_list_int) + self.assertEqual(g([1]), "list of ints") + self.assertIs(g.dispatch(list[int]), g_list_int) + + def test_pep585_annotation(self): + @functools.singledispatch + def g(obj): + return "base" + # previously this failed with: 'not a class' + @g.register + def g_list_int(li: list[int]): + return "list of ints" + self.assertEqual(g([1,2,3]), "list of ints") + self.assertIs(g.dispatch(tuple[int]), g_list_int) + + def test_pep585_all_must_match(self): + @functools.singledispatch + def g(obj): + return "base" + def g_list_int(li): + return "list of ints" + def g_list_not_ints(l): + # should only trigger if list doesnt match `list[int]` + # ie. at least one element is not an int + return "!all(int)" + + g.register(list[int], g_list_int) + g.register(list, g_list_not_ints) + + self.assertEqual(g([1,2,3]), "list of ints") + self.assertEqual(g([1,2,3, "hello"]), "!all(int)") + self.assertEqual(g([3.14]), "!all(int)") + + self.assertIs(g.dispatch(list[int]), g_list_int) + self.assertIs(g.dispatch(list[str]), g_list_not_ints) + self.assertIs(g.dispatch(list[float]), g_list_not_ints) + self.assertIs(g.dispatch(list[int|str]), g_list_not_ints) + + def test_pep585_specificity(self): + @functools.singledispatch + def g(obj): + return "base" + @g.register + def g_list(l: list): + return "basic list" + @g.register + def g_list_int(li: list[int]): + return "int" + @g.register + def g_list_str(ls: list[str]): + return "str" + @g.register + def g_list_mixed_int_str(lmis:list[int|str]): + return "int|str" + @g.register + def g_list_mixed_int_float(lmif: list[int|float]): + return "int|float" + @g.register + def g_list_mixed_int_float_str(lmifs: list[int|float|str]): + return "int|float|str" + + # this matches list, list[int], list[int|str], list[int|float|str], list[int|...|...|...|...] + # but list[int] is the most specific, so that is correct + self.assertEqual(g([1,2,3]), "int") + + # this cannot match list[int] because of the string + # it does match list[int|float|str] but this is incorrect because, + # the most specific is list[int|str] + self.assertEqual(g([1,2,3, "hello"]), "int|str") + + # list[float] is not mapped so, + # list[int|float] is the most specific + self.assertEqual(g([3.14]), "int|float") + + self.assertIs(g.dispatch(list[int]), g_list_int) + self.assertIs(g.dispatch(list[float]), g_list_mixed_int_float) + self.assertIs(g.dispatch(list[int|str]), g_list_mixed_int_str) + + def test_pep585_ambiguous(self): + @functools.singledispatch + def g(obj): + return "base" + @g.register + def g_list_int_float(l: list[int|float]): + return "int|float" + @g.register + def g_list_int_str(l: list[int|str]): + return "int|str" + @g.register + def g_list_int(l: list[int]): + return "int only" + + self.assertEqual(g([3.1]), "int|float") # floats only + self.assertEqual(g(["hello"]), "int|str") # strings only + self.assertEqual(g([3.14, 1]), "int|float") # ints and floats + self.assertEqual(g(["hello", 1]), "int|str") # ints and strings + + self.assertIs(g.dispatch(list[int]), g_list_int) + self.assertIs(g.dispatch(list[str]), g_list_int_str) + self.assertIs(g.dispatch(list[float]), g_list_int_float) + self.assertIs(g.dispatch(list[int|str]), g_list_int_str) + self.assertIs(g.dispatch(list[int|float]), g_list_int_float) + + # these should fail because it's unclear which target is "correct" + with self.assertRaises(RuntimeError): + g([1]) + + self.assertRaises(RuntimeError, g.dispatch(list[int])) + + def test_pep585_method_basic(self): + class A: + @functools.singledispatchmethod + def g(obj): + return "base" + def g_list_int(li): + return "list of ints" + + a = A() + a.g.register(list[int], A.g_list_int) + self.assertEqual(a.g([1]), "list of ints") + self.assertIs(a.g.dispatch(list[int]), A.g_list_int) + + def test_pep585_method_annotation(self): + class A: + @functools.singledispatchmethod + def g(obj): + return "base" + # previously this failed with: 'not a class' + @g.register + def g_list_int(li: list[int]): + return "list of ints" + a = A() + self.assertEqual(a.g([1,2,3]), "list of ints") + self.assertIs(g.dispatch(tuple[int]), A.g_list_int) + + def test_pep585_method_all_must_match(self): + class A: + @functools.singledispatch + def g(obj): + return "base" + def g_list_int(li): + return "list of ints" + def g_list_not_ints(l): + # should only trigger if list doesnt match `list[int]` + # ie. at least one element is not an int + return "!all(int)" + + a = A() + a.g.register(list[int], A.g_list_int) + a.g.register(list, A.g_list_not_ints) + + self.assertEqual(a.g([1,2,3]), "list of ints") + self.assertEqual(a.g([1,2,3, "hello"]), "!all(int)") + self.assertEqual(a.g([3.14]), "!all(int)") + + self.assertIs(a.g.dispatch(list[int]), A.g_list_int) + self.assertIs(a.g.dispatch(list[str]), A.g_list_not_ints) + self.assertIs(a.g.dispatch(list[float]), A.g_list_not_ints) + self.assertIs(a.g.dispatch(list[int|str]), A.g_list_not_ints) + + def test_pep585_method_specificity(self): + class A: + @functools.singledispatch + def g(obj): + return "base" + @g.register + def g_list(l: list): + return "basic list" + @g.register + def g_list_int(li: list[int]): + return "int" + @g.register + def g_list_str(ls: list[str]): + return "str" + @g.register + def g_list_mixed_int_str(lmis:list[int|str]): + return "int|str" + @g.register + def g_list_mixed_int_float(lmif: list[int|float]): + return "int|float" + @g.register + def g_list_mixed_int_float_str(lmifs: list[int|float|str]): + return "int|float|str" + + a = A() + + # this matches list, list[int], list[int|str], list[int|float|str], list[int|...|...|...|...] + # but list[int] is the most specific, so that is correct + self.assertEqual(a.g([1,2,3]), "int") + + # this cannot match list[int] because of the string + # it does match list[int|float|str] but this is incorrect because, + # the most specific is list[int|str] + self.assertEqual(a.g([1,2,3, "hello"]), "int|str") + + # list[float] is not mapped so, + # list[int|float] is the most specific + self.assertEqual(a.g([3.14]), "int|float") + + self.assertIs(a.g.dispatch(list[int]), A.g_list_int) + self.assertIs(a.g.dispatch(list[float]), A.g_list_mixed_int_float) + self.assertIs(a.g.dispatch(list[int|str]), A.g_list_mixed_int_str) + + def test_pep585_method_ambiguous(self): + class A: + @functools.singledispatch + def g(obj): + return "base" + @g.register + def g_list_int_float(l: list[int|float]): + return "int|float" + @g.register + def g_list_int_str(l: list[int|str]): + return "int|str" + @g.register + def g_list_int(l: list[int]): + return "int only" + + a = A() + + self.assertEqual(a.g([3.1]), "int|float") # floats only + self.assertEqual(a.g(["hello"]), "int|str") # strings only + self.assertEqual(a.g([3.14, 1]), "int|float") # ints and floats + self.assertEqual(a.g(["hello", 1]), "int|str") # ints and strings + + self.assertIs(a.g.dispatch(list[int]), A.g_list_int) + self.assertIs(a.g.dispatch(list[str]), A.g_list_int_str) + self.assertIs(a.g.dispatch(list[float]), A.g_list_int_float) + self.assertIs(a.g.dispatch(list[int|str]), A.g_list_int_str) + self.assertIs(a.g.dispatch(list[int|float]), A.g_list_int_float) + + # these should fail because it's unclear which target is "correct" + self.assertRaises(RuntimeError, a.g([1])) + + self.assertRaises(RuntimeError, a.g.dispatch(list[int])) + def test_simple_overloads(self): @functools.singledispatch def g(obj): @@ -3238,18 +3483,12 @@ def test_register_genericalias(self): def f(arg): return "default" - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(list[int], lambda arg: "types.GenericAlias") - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(typing.List[int], lambda arg: "typing.GenericAlias") - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)") - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]") + f.register(list[int], lambda arg: "types.GenericAlias") + f.register(list[float] | str, lambda arg: "types.UnionTypes(types.GenericAlias)") - self.assertEqual(f([1]), "default") - self.assertEqual(f([1.0]), "default") - self.assertEqual(f(""), "default") + self.assertEqual(f([1]), "types.GenericAlias") + self.assertEqual(f([1.0]), "types.UnionTypes(types.GenericAlias)") + self.assertEqual(f(""), "types.UnionTypes(types.GenericAlias)") self.assertEqual(f(b""), "default") def test_register_genericalias_decorator(self): @@ -3257,41 +3496,39 @@ def test_register_genericalias_decorator(self): def f(arg): return "default" - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(list[int]) - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(typing.List[int]) - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(list[int] | str) - with self.assertRaisesRegex(TypeError, "Invalid first argument to "): - f.register(typing.List[int] | str) + f.register(list[int]) + #f.register(typing.List[int]) + f.register(list[int] | str) + #f.register(typing.List[int] | str) def test_register_genericalias_annotation(self): @functools.singledispatch def f(arg): return "default" - with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): - @f.register - def _(arg: list[int]): - return "types.GenericAlias" + @f.register + def _(arg: list[int]): + return "types.GenericAlias" + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): @f.register def _(arg: typing.List[float]): return "typing.GenericAlias" - with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): - @f.register - def _(arg: list[int] | str): - return "types.UnionType(types.GenericAlias)" + + @f.register + def _(arg: list[bytes] | str): + return "types.UnionType(types.GenericAlias)" + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): @f.register def _(arg: typing.List[float] | bytes): return "typing.Union[typing.GenericAlias]" - self.assertEqual(f([1]), "default") + self.assertEqual(f([1]), "types.GenericAlias") self.assertEqual(f([1.0]), "default") - self.assertEqual(f(""), "default") + self.assertEqual(f(""), "types.UnionType(types.GenericAlias)") self.assertEqual(f(b""), "default") + self.assertEqual(f([b""]), "types.UnionType(types.GenericAlias)") def test_forward_reference(self): @functools.singledispatch