|
| 1 | +from typing import TypeVar, Generic, Callable, Any, Optional |
| 2 | + |
| 3 | +import basilisp.lang.atom as atom |
| 4 | +import basilisp.lang.map as lmap |
| 5 | +import basilisp.lang.symbol as sym |
| 6 | +from basilisp.util import Maybe |
| 7 | + |
| 8 | +T = TypeVar('T') |
| 9 | +DispatchFunction = Callable[..., T] |
| 10 | +Method = Callable[..., Any] |
| 11 | + |
| 12 | + |
| 13 | +class MultiFunction(Generic[T]): |
| 14 | + __slots__ = ('_name', '_default', '_dispatch', '_methods') |
| 15 | + |
| 16 | + def __init__(self, name: sym.Symbol, dispatch: DispatchFunction, default: T) -> None: |
| 17 | + self._name = name # pylint:disable=assigning-non-slot |
| 18 | + self._default = default # pylint:disable=assigning-non-slot |
| 19 | + self._dispatch = dispatch # pylint:disable=assigning-non-slot |
| 20 | + self._methods: atom.Atom = atom.Atom(lmap.Map.empty()) # pylint:disable=assigning-non-slot |
| 21 | + |
| 22 | + def __call__(self, *args, **kwargs): |
| 23 | + key = self._dispatch(*args, **kwargs) |
| 24 | + method_cache = self.methods |
| 25 | + method = Maybe(method_cache.entry(key, None)).or_else( |
| 26 | + lambda: method_cache.entry(self._default, None)) |
| 27 | + if method: |
| 28 | + return method(*args, **kwargs) |
| 29 | + raise NotImplementedError |
| 30 | + |
| 31 | + @staticmethod |
| 32 | + def __add_method(m: lmap.Map, key: T, method: Method) -> lmap.Map: |
| 33 | + """Swap the methods atom to include method with key.""" |
| 34 | + return m.assoc(key, method) |
| 35 | + |
| 36 | + def add_method(self, key: T, method: Method) -> None: |
| 37 | + """Add a new method to this function which will respond for |
| 38 | + key returned from the dispatch function.""" |
| 39 | + self._methods.swap(MultiFunction.__add_method, key, method) |
| 40 | + |
| 41 | + def get_method(self, key: T) -> Optional[Method]: |
| 42 | + """Return the method which would handle this dispatch key or |
| 43 | + None if no method defined for this key and no default.""" |
| 44 | + method_cache = self.methods |
| 45 | + # The 'type: ignore' comment below silences a spurious MyPy error |
| 46 | + # about having a return statement in a method which does not return. |
| 47 | + return Maybe(method_cache.entry(key, None)).or_else( |
| 48 | + lambda: method_cache.entry(self._default, None)) # type: ignore |
| 49 | + |
| 50 | + @staticmethod |
| 51 | + def __remove_method(m: lmap.Map, key: T) -> lmap.Map: |
| 52 | + """Swap the methods atom to remove method with key.""" |
| 53 | + return m.dissoc(key) |
| 54 | + |
| 55 | + def remove_method(self, key: T) -> Optional[Method]: |
| 56 | + """Remove the method defined for this key and return it.""" |
| 57 | + method = self.methods.entry(key, None) |
| 58 | + if method: |
| 59 | + self._methods.swap(MultiFunction.__remove_method, key) |
| 60 | + return method |
| 61 | + |
| 62 | + def remove_all_methods(self) -> None: |
| 63 | + """Remove all methods defined for this multi-function.""" |
| 64 | + self._methods.reset(lmap.Map.empty()) |
| 65 | + |
| 66 | + @property |
| 67 | + def default(self) -> T: |
| 68 | + return self._default |
| 69 | + |
| 70 | + @property |
| 71 | + def methods(self) -> lmap.Map: |
| 72 | + return self._methods.deref() |
| 73 | + |
| 74 | + |
| 75 | +def multifn(dispatch: DispatchFunction, default=None) -> MultiFunction[T]: |
| 76 | + """Decorator function which can be used to make Python multi functions.""" |
| 77 | + name = sym.symbol(dispatch.__qualname__, ns=dispatch.__module__) |
| 78 | + return MultiFunction(name, dispatch, default) |
0 commit comments