Skip to content

Commit 678806d

Browse files
committed
fixes #560
1 parent 50a1581 commit 678806d

File tree

4 files changed

+32
-20
lines changed

4 files changed

+32
-20
lines changed

fastcore/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.5.41"
1+
__version__ = "1.5.42"

fastcore/basics.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,13 @@ def tonull(x):
9898
return null if x is None else x
9999

100100
# %% ../nbs/01_basics.ipynb 41
101-
def get_class(nm, *fld_names, sup=None, doc=None, funcs=None, **flds):
101+
def get_class(nm, *fld_names, sup=None, doc=None, funcs=None, anno=None, **flds):
102102
"Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`"
103103
attrs = {}
104-
for f in fld_names: attrs[f] = None
104+
if not anno: anno = {}
105+
for f in fld_names:
106+
attrs[f] = None
107+
if f not in anno: anno[f] = typing.Any
105108
for f in listify(funcs): attrs[f.__name__] = f
106109
for k,v in flds.items(): attrs[k] = v
107110
sup = ifnone(sup, ())
@@ -111,22 +114,23 @@ def _init(self, *args, **kwargs):
111114
for i,v in enumerate(args): setattr(self, list(attrs.keys())[i], v)
112115
for k,v in kwargs.items(): setattr(self,k,v)
113116

114-
all_flds = [*fld_names,*flds.keys()]
117+
attrs['_fields'] = [*fld_names,*flds.keys()]
115118
def _eq(self,b):
116-
return all([getattr(self,k)==getattr(b,k) for k in all_flds])
119+
return all([getattr(self,k)==getattr(b,k) for k in self._fields])
117120

118-
if not sup: attrs['__repr__'] = basic_repr(all_flds)
121+
if not sup: attrs['__repr__'] = basic_repr(attrs['_fields'])
119122
attrs['__init__'] = _init
120123
attrs['__eq__'] = _eq
124+
if anno: attrs['__annotations__'] = anno
121125
res = type(nm, sup, attrs)
122126
if doc is not None: res.__doc__ = doc
123127
return res
124128

125129
# %% ../nbs/01_basics.ipynb 45
126-
def mk_class(nm, *fld_names, sup=None, doc=None, funcs=None, mod=None, **flds):
130+
def mk_class(nm, *fld_names, sup=None, doc=None, funcs=None, mod=None, anno=None, **flds):
127131
"Create a class using `get_class` and add to the caller's module"
128132
if mod is None: mod = sys._getframe(1).f_locals
129-
res = get_class(nm, *fld_names, sup=sup, doc=doc, funcs=funcs, **flds)
133+
res = get_class(nm, *fld_names, sup=sup, doc=doc, funcs=funcs, anno=anno, **flds)
130134
mod[nm] = res
131135

132136
# %% ../nbs/01_basics.ipynb 50

nbs/01_basics.ipynb

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
"from __future__ import annotations\n",
3434
"from fastcore.test import *\n",
3535
"from nbdev.showdoc import *\n",
36-
"from fastcore.nb_imports import *"
36+
"from fastcore.nb_imports import *\n",
37+
"from inspect import get_annotations"
3738
]
3839
},
3940
{
@@ -553,10 +554,13 @@
553554
"outputs": [],
554555
"source": [
555556
"#|export\n",
556-
"def get_class(nm, *fld_names, sup=None, doc=None, funcs=None, **flds):\n",
557+
"def get_class(nm, *fld_names, sup=None, doc=None, funcs=None, anno=None, **flds):\n",
557558
" \"Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`\"\n",
558559
" attrs = {}\n",
559-
" for f in fld_names: attrs[f] = None\n",
560+
" if not anno: anno = {}\n",
561+
" for f in fld_names:\n",
562+
" attrs[f] = None\n",
563+
" if f not in anno: anno[f] = typing.Any\n",
560564
" for f in listify(funcs): attrs[f.__name__] = f\n",
561565
" for k,v in flds.items(): attrs[k] = v\n",
562566
" sup = ifnone(sup, ())\n",
@@ -566,13 +570,14 @@
566570
" for i,v in enumerate(args): setattr(self, list(attrs.keys())[i], v)\n",
567571
" for k,v in kwargs.items(): setattr(self,k,v)\n",
568572
"\n",
569-
" all_flds = [*fld_names,*flds.keys()]\n",
573+
" attrs['_fields'] = [*fld_names,*flds.keys()]\n",
570574
" def _eq(self,b):\n",
571-
" return all([getattr(self,k)==getattr(b,k) for k in all_flds])\n",
575+
" return all([getattr(self,k)==getattr(b,k) for k in self._fields])\n",
572576
"\n",
573-
" if not sup: attrs['__repr__'] = basic_repr(all_flds)\n",
577+
" if not sup: attrs['__repr__'] = basic_repr(attrs['_fields'])\n",
574578
" attrs['__init__'] = _init\n",
575579
" attrs['__eq__'] = _eq\n",
580+
" if anno: attrs['__annotations__'] = anno\n",
576581
" res = type(nm, sup, attrs)\n",
577582
" if doc is not None: res.__doc__ = doc\n",
578583
" return res"
@@ -592,7 +597,8 @@
592597
"\n",
593598
"### get_class\n",
594599
"\n",
595-
"> get_class (nm, *fld_names, sup=None, doc=None, funcs=None, **flds)\n",
600+
"> get_class (nm, *fld_names, sup=None, doc=None, funcs=None, anno=None,\n",
601+
"> **flds)\n",
596602
"\n",
597603
"*Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`*"
598604
],
@@ -603,7 +609,8 @@
603609
"\n",
604610
"### get_class\n",
605611
"\n",
606-
"> get_class (nm, *fld_names, sup=None, doc=None, funcs=None, **flds)\n",
612+
"> get_class (nm, *fld_names, sup=None, doc=None, funcs=None, anno=None,\n",
613+
"> **flds)\n",
607614
"\n",
608615
"*Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`*"
609616
]
@@ -634,7 +641,7 @@
634641
}
635642
],
636643
"source": [
637-
"_t = get_class('_t', 'a', b=2)\n",
644+
"_t = get_class('_t', 'a', b=2, anno={'b':int})\n",
638645
"t = _t()\n",
639646
"test_eq(t.a, None)\n",
640647
"test_eq(t.b, 2)\n",
@@ -645,6 +652,7 @@
645652
"test_eq(t.a, 1)\n",
646653
"test_eq(t.b, 3)\n",
647654
"test_eq(t, pickle.loads(pickle.dumps(t)))\n",
655+
"test_eq(get_annotations(_t), {'b':int, 'a':typing.Any})\n",
648656
"repr(t)"
649657
]
650658
},
@@ -662,10 +670,10 @@
662670
"outputs": [],
663671
"source": [
664672
"#|export\n",
665-
"def mk_class(nm, *fld_names, sup=None, doc=None, funcs=None, mod=None, **flds):\n",
673+
"def mk_class(nm, *fld_names, sup=None, doc=None, funcs=None, mod=None, anno=None, **flds):\n",
666674
" \"Create a class using `get_class` and add to the caller's module\"\n",
667675
" if mod is None: mod = sys._getframe(1).f_locals\n",
668-
" res = get_class(nm, *fld_names, sup=sup, doc=doc, funcs=funcs, **flds)\n",
676+
" res = get_class(nm, *fld_names, sup=sup, doc=doc, funcs=funcs, anno=anno, **flds)\n",
669677
" mod[nm] = res"
670678
]
671679
},

settings.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ author = Jeremy Howard and Sylvain Gugger
88
author_email = [email protected]
99
copyright = fast.ai
1010
branch = master
11-
version = 1.5.41
11+
version = 1.5.42
1212
min_python = 3.7
1313
audience = Developers
1414
language = English

0 commit comments

Comments
 (0)