Skip to content

Commit d5e43c9

Browse files
authored
Merge pull request #94 from Salehbigdeli/master
Add Method Dispatching to class methods
2 parents 8622806 + ace354d commit d5e43c9

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

fastcore/dispatch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(self, funcs=(), bases=()):
8181
self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))
8282
for o in L(funcs): self.add(o)
8383
self.inst = None
84+
self.owner = None
8485

8586
def add(self, f):
8687
"Add type `t` and function `f`"
@@ -109,10 +110,12 @@ def __call__(self, *args, **kwargs):
109110
f = self[tuple(ts)]
110111
if not f: return args[0]
111112
if self.inst is not None: f = MethodType(f, self.inst)
113+
elif self.owner is not None: f = MethodType(f, self.owner)
112114
return f(*args, **kwargs)
113115

114116
def __get__(self, inst, owner):
115117
self.inst = inst
118+
self.owner = owner
116119
return self
117120

118121
def __getitem__(self, k):

nbs/03_dispatch.ipynb

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@
378378
" self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))\n",
379379
" for o in L(funcs): self.add(o)\n",
380380
" self.inst = None\n",
381+
" self.owner = None\n",
381382
"\n",
382383
" def add(self, f):\n",
383384
" \"Add type `t` and function `f`\"\n",
@@ -406,10 +407,12 @@
406407
" f = self[tuple(ts)]\n",
407408
" if not f: return args[0]\n",
408409
" if self.inst is not None: f = MethodType(f, self.inst)\n",
410+
" elif self.owner is not None: f = MethodType(f, self.owner)\n",
409411
" return f(*args, **kwargs)\n",
410412
"\n",
411413
" def __get__(self, inst, owner):\n",
412414
" self.inst = inst\n",
415+
" self.owner = owner\n",
413416
" return self\n",
414417
"\n",
415418
" def __getitem__(self, k):\n",
@@ -550,7 +553,7 @@
550553
{
551554
"data": {
552555
"text/markdown": [
553-
"<h4 id=\"TypeDispatch.add\" class=\"doc_header\"><code>TypeDispatch.add</code><a href=\"__main__.py#L10\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
556+
"<h4 id=\"TypeDispatch.add\" class=\"doc_header\"><code>TypeDispatch.add</code><a href=\"__main__.py#L11\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
554557
"\n",
555558
"> <code>TypeDispatch.add</code>(**`f`**)\n",
556559
"\n",
@@ -873,7 +876,7 @@
873876
{
874877
"data": {
875878
"text/markdown": [
876-
"<h4 id=\"TypeDispatch.__call__\" class=\"doc_header\"><code>TypeDispatch.__call__</code><a href=\"__main__.py#L32\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
879+
"<h4 id=\"TypeDispatch.__call__\" class=\"doc_header\"><code>TypeDispatch.__call__</code><a href=\"__main__.py#L33\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
877880
"\n",
878881
"> <code>TypeDispatch.__call__</code>(**\\*`args`**, **\\*\\*`kwargs`**)\n",
879882
"\n",
@@ -961,7 +964,7 @@
961964
{
962965
"data": {
963966
"text/markdown": [
964-
"<h4 id=\"TypeDispatch.returns\" class=\"doc_header\"><code>TypeDispatch.returns</code><a href=\"__main__.py#L20\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
967+
"<h4 id=\"TypeDispatch.returns\" class=\"doc_header\"><code>TypeDispatch.returns</code><a href=\"__main__.py#L21\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
965968
"\n",
966969
"> <code>TypeDispatch.returns</code>(**`x`**)\n",
967970
"\n",
@@ -1080,6 +1083,36 @@
10801083
"test_eq(a2.f(()), (1,))"
10811084
]
10821085
},
1086+
{
1087+
"cell_type": "markdown",
1088+
"metadata": {},
1089+
"source": [
1090+
"#### Using TypeDispatch With Class Methods\n",
1091+
"\n",
1092+
"You can use `TypeDispatch` when defining class methods too:"
1093+
]
1094+
},
1095+
{
1096+
"cell_type": "code",
1097+
"execution_count": null,
1098+
"metadata": {},
1099+
"outputs": [],
1100+
"source": [
1101+
"def m_nin(cls, x:(str,numbers.Integral)): return str(x)+'1'\n",
1102+
"def m_bll(cls, x:bool): cls.foo='a'\n",
1103+
"def m_num(cls, x:numbers.Number): return x*2\n",
1104+
"\n",
1105+
"t = TypeDispatch([m_nin,m_num,m_bll])\n",
1106+
"class A: f = t # set class attribute `f` equal to a TypeDispatch\n",
1107+
"\n",
1108+
"test_eq(A.f(1), '11') #dispatch to m_nin\n",
1109+
"test_eq(A.f(1.), 2.) #dispatch to m_num\n",
1110+
"test_is(A.f.owner, A)\n",
1111+
"\n",
1112+
"A.f(False) # this triggers t.m_bll to run, which sets A.foo to 'a'\n",
1113+
"test_eq(A.foo, 'a')"
1114+
]
1115+
},
10831116
{
10841117
"cell_type": "markdown",
10851118
"metadata": {},
@@ -1396,6 +1429,7 @@
13961429
"Converted 04_transform.ipynb.\n",
13971430
"Converted 05_logargs.ipynb.\n",
13981431
"Converted 06_meta.ipynb.\n",
1432+
"Converted 07_script.ipynb.\n",
13991433
"Converted index.ipynb.\n"
14001434
]
14011435
}

0 commit comments

Comments
 (0)