Skip to content

Commit 664e540

Browse files
authored
Add multi-methods to Basilisp (#222)
* Add multi-methods to Basilisp * Simplify call logic for MultiFunction
1 parent e52d2c6 commit 664e540

File tree

3 files changed

+185
-0
lines changed

3 files changed

+185
-0
lines changed

src/basilisp/core/__init__.lpy

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,3 +1351,51 @@
13511351
[pattern s]
13521352
(lazy-re-seq (seq (re/finditer pattern s))))
13531353

1354+
;;;;;;;;;;;;;;;;;;
1355+
;; Multimethods ;;
1356+
;;;;;;;;;;;;;;;;;;
1357+
1358+
(import basilisp.lang.multifn)
1359+
1360+
(defmacro defmulti
1361+
"Define a new multimethod with the dispatch function."
1362+
[name & body]
1363+
(let [doc (when (string? (first body))
1364+
(first body))
1365+
body (if doc
1366+
(rest body)
1367+
body)
1368+
dispatch-fn (first body)
1369+
opts (apply hash-map (rest body))]
1370+
`(def ~name (basilisp.lang.multifn/MultiFunction ~name
1371+
~dispatch-fn
1372+
(or (:default opts) :default)))))
1373+
1374+
(defmacro defmethod
1375+
"Add a new method to the multi-function which responds to dispatch-val."
1376+
[multifn dispatch-val & fn-tail]
1377+
`(.add-method multi-fn dispatch-val (fn ~@fn-tail)))
1378+
1379+
(defn methods
1380+
"Return a map of dispatch values to methods for the given multi function."
1381+
[multifn]
1382+
(.-methods multifn))
1383+
1384+
(defn get-method
1385+
"Return the method which would respond to dispatch-val or nil if no method
1386+
exists for dispatch-val."
1387+
[multifn dispatch-val]
1388+
(.get-method multifn dispatch-val))
1389+
1390+
(defn remove-method
1391+
"Remove the method which responds to dispatch-val, if it exists. Return the
1392+
multi function."
1393+
[multifn dispatch-val]
1394+
(.remove-method multifn dispatch-val)
1395+
multifn)
1396+
1397+
(defn remove-all-methods
1398+
"Remove all method for this multi-function. Return the multi function."
1399+
[multifn dispatch-val]
1400+
(.remove-all-methods multifn)
1401+
multifn)

src/basilisp/lang/multifn.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from typing import TypeVar, Generic, Callable, Any, Optional
2+
3+
import basilisp.lang.atom as atom
4+
import basilisp.lang.map as lmap
5+
import basilisp.lang.symbol as sym
6+
from basilisp.util import Maybe
7+
8+
T = TypeVar('T')
9+
DispatchFunction = Callable[..., T]
10+
Method = Callable[..., Any]
11+
12+
13+
class MultiFunction(Generic[T]):
14+
__slots__ = ('_name', '_default', '_dispatch', '_methods')
15+
16+
def __init__(self, name: sym.Symbol, dispatch: DispatchFunction, default: T) -> None:
17+
self._name = name # pylint:disable=assigning-non-slot
18+
self._default = default # pylint:disable=assigning-non-slot
19+
self._dispatch = dispatch # pylint:disable=assigning-non-slot
20+
self._methods: atom.Atom = atom.Atom(lmap.Map.empty()) # pylint:disable=assigning-non-slot
21+
22+
def __call__(self, *args, **kwargs):
23+
key = self._dispatch(*args, **kwargs)
24+
method_cache = self.methods
25+
method = Maybe(method_cache.entry(key, None)).or_else(
26+
lambda: method_cache.entry(self._default, None))
27+
if method:
28+
return method(*args, **kwargs)
29+
raise NotImplementedError
30+
31+
@staticmethod
32+
def __add_method(m: lmap.Map, key: T, method: Method) -> lmap.Map:
33+
"""Swap the methods atom to include method with key."""
34+
return m.assoc(key, method)
35+
36+
def add_method(self, key: T, method: Method) -> None:
37+
"""Add a new method to this function which will respond for
38+
key returned from the dispatch function."""
39+
self._methods.swap(MultiFunction.__add_method, key, method)
40+
41+
def get_method(self, key: T) -> Optional[Method]:
42+
"""Return the method which would handle this dispatch key or
43+
None if no method defined for this key and no default."""
44+
method_cache = self.methods
45+
# The 'type: ignore' comment below silences a spurious MyPy error
46+
# about having a return statement in a method which does not return.
47+
return Maybe(method_cache.entry(key, None)).or_else(
48+
lambda: method_cache.entry(self._default, None)) # type: ignore
49+
50+
@staticmethod
51+
def __remove_method(m: lmap.Map, key: T) -> lmap.Map:
52+
"""Swap the methods atom to remove method with key."""
53+
return m.dissoc(key)
54+
55+
def remove_method(self, key: T) -> Optional[Method]:
56+
"""Remove the method defined for this key and return it."""
57+
method = self.methods.entry(key, None)
58+
if method:
59+
self._methods.swap(MultiFunction.__remove_method, key)
60+
return method
61+
62+
def remove_all_methods(self) -> None:
63+
"""Remove all methods defined for this multi-function."""
64+
self._methods.reset(lmap.Map.empty())
65+
66+
@property
67+
def default(self) -> T:
68+
return self._default
69+
70+
@property
71+
def methods(self) -> lmap.Map:
72+
return self._methods.deref()
73+
74+
75+
def multifn(dispatch: DispatchFunction, default=None) -> MultiFunction[T]:
76+
"""Decorator function which can be used to make Python multi functions."""
77+
name = sym.symbol(dispatch.__qualname__, ns=dispatch.__module__)
78+
return MultiFunction(name, dispatch, default)

tests/basilisp/multifn_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
3+
import basilisp.lang.keyword as kw
4+
import basilisp.lang.map as lmap
5+
import basilisp.lang.multifn as multifn
6+
import basilisp.lang.symbol as sym
7+
8+
9+
def test_multi_function():
10+
def dispatch(v) -> kw.Keyword:
11+
if v == "i":
12+
return kw.keyword('a')
13+
elif v == 'ii':
14+
return kw.keyword('b')
15+
return kw.keyword('default')
16+
17+
def fn_a(v) -> str:
18+
return "1"
19+
20+
def fn_b(v) -> str:
21+
return "2"
22+
23+
def fn_default(v) -> str:
24+
return "BLAH"
25+
26+
f = multifn.MultiFunction(sym.symbol('test-fn'), dispatch, kw.keyword('default'))
27+
f.add_method(kw.keyword('a'), fn_a)
28+
f.add_method(kw.keyword('b'), fn_b)
29+
f.add_method(kw.keyword('default'), fn_default)
30+
31+
assert lmap.map({kw.keyword('a'): fn_a,
32+
kw.keyword('b'): fn_b,
33+
kw.keyword('default'): fn_default}) == f.methods
34+
35+
assert kw.keyword('default') == f.default
36+
37+
assert fn_a is f.get_method(kw.keyword('a'))
38+
assert fn_b is f.get_method(kw.keyword('b'))
39+
assert fn_default is f.get_method(kw.keyword('default'))
40+
assert fn_default is f.get_method(kw.keyword('other'))
41+
42+
assert "1" == f("i")
43+
assert "2" == f("ii")
44+
assert "BLAH" == f("iii")
45+
assert "BLAH" == f("whatever")
46+
47+
f.remove_method(kw.keyword('b'))
48+
49+
assert "1" == f("i")
50+
assert "BLAH" == f("ii")
51+
assert "BLAH" == f("iii")
52+
assert "BLAH" == f("whatever")
53+
54+
f.remove_all_methods()
55+
56+
assert lmap.Map.empty() == f.methods
57+
58+
with pytest.raises(NotImplementedError):
59+
f('blah')

0 commit comments

Comments
 (0)