Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 73 additions & 16 deletions Lib/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,8 +843,12 @@ def is_strict_base(typ):
mro.append(subcls)
return _c3_mro(cls, abcs=mro)

def _find_impl(cls, registry):
"""Returns the best matching implementation from *registry* for type *cls*.
def _pep585_registry_matches(cls, registry):
from typing import get_origin
return (i for i in registry.keys() if get_origin(i) == cls)

def _find_impl_match(cls_obj, registry):
"""Returns the best matching implementation from *registry* for type *cls_obj*.

Where there is no registered implementation for a specific type, its method
resolution order is used to find a more generic implementation.
Expand All @@ -853,8 +857,35 @@ def _find_impl(cls, registry):
*object* type, this function may return None.

"""
cls = cls_obj if isinstance(cls_obj, type) else cls_obj.__class__
mro = _compose_mro(cls, registry.keys())
match = None

from typing import get_origin, get_args

if (not isinstance(cls_obj, type) and
len(cls_obj) > 0 and # dont try to match the types of empty containers
any(_pep585_registry_matches(cls, registry))):
# check containers that match cls first
for t in _pep585_registry_matches(cls, registry):
if not all((isinstance(i, get_args(t)) for i in cls_obj)):
continue

if match is None:
match = t

else:
match_args = get_args(get_args(match)[0])
t_args = get_args(get_args(t)[0])
if len(match_args) == len(t_args):
raise RuntimeError("Ambiguous dispatch: {} or {}".format( match, t))

elif len(t_args)<len(match_args):
match = t

if match:
return match

for t in mro:
if match is not None:
# If *match* is an implicit ABC but there is another unrelated,
Expand All @@ -867,7 +898,14 @@ def _find_impl(cls, registry):
break
if t in registry:
match = t
return registry.get(match)

return match

def _find_impl(cls_obj, registry):
return registry.get(
_find_impl_match(cls_obj, registry)
)


def singledispatch(func):
"""Single-dispatch generic function decorator.
Expand All @@ -887,34 +925,52 @@ def singledispatch(func):
dispatch_cache = weakref.WeakKeyDictionary()
cache_token = None

def dispatch(cls):
def _fetch_dispatch_with_cache(cls):
try:
impl = dispatch_cache[cls]
except KeyError:
try:
impl = registry[cls]
except KeyError:
impl = _find_impl(cls, registry)
dispatch_cache[cls] = impl
return impl


def dispatch(cls_obj):
"""generic_func.dispatch(cls) -> <function implementation>

Runs the dispatch algorithm to return the best available implementation
for the given *cls* registered on *generic_func*.

"""
cls = cls_obj.__class__
nonlocal cache_token
if cache_token is not None:
current_token = get_cache_token()
if cache_token != current_token:
dispatch_cache.clear()
cache_token = current_token
try:
impl = dispatch_cache[cls]
except KeyError:
try:
impl = registry[cls]
except KeyError:
impl = _find_impl(cls, registry)
dispatch_cache[cls] = impl
return impl

# if PEP-585 types are not registered for the given *cls*,
# then we can use the cache. Otherwise, the cache cannot be used
# because we need to confirm every item matches first
if not any(_pep585_registry_matches(cls, registry)):
return _fetch_dispatch_with_cache(cls)

return _find_impl(cls_obj, registry)

def _is_valid_dispatch_type(cls):
if isinstance(cls, type):
return True

if isinstance(cls, GenericAlias):
from typing import get_args
return all(isinstance(arg, (type, UnionType)) for arg in get_args(cls))

return (isinstance(cls, UnionType) and
all(isinstance(arg, type) for arg in cls.__args__))
all(isinstance(arg, (type, GenericAlias)) for arg in cls.__args__))


def register(cls, func=None):
"""generic_func.register(cls, func) -> func
Expand All @@ -932,6 +988,7 @@ def register(cls, func=None):
f"Invalid first argument to `register()`. "
f"{cls!r} is not a class or union type."
)

ann = getattr(cls, '__annotate__', None)
if ann is None:
raise TypeError(
Expand Down Expand Up @@ -976,7 +1033,7 @@ def wrapper(*args, **kw):
if not args:
raise TypeError(f'{funcname} requires at least '
'1 positional argument')
return dispatch(args[0].__class__)(*args, **kw)
return dispatch(args[0])(*args, **kw)

funcname = getattr(func, '__name__', 'singledispatch function')
registry[object] = func
Expand Down Expand Up @@ -1064,7 +1121,7 @@ def __call__(self, /, *args, **kwargs):
'singledispatchmethod method')
raise TypeError(f'{funcname} requires at least '
'1 positional argument')
return self._dispatch(args[0].__class__).__get__(self._obj, self._cls)(*args, **kwargs)
return self._dispatch(args[0]).__get__(self._obj, self._cls)(*args, **kwargs)

def __getattr__(self, name):
# Resolve these attributes lazily to speed up creation of
Expand Down
Loading
Loading