Skip to content

Commit 2a495ac

Browse files
bugfix wrap dataclass_abc class decorator like in dataclasses.dataclass
1 parent 2453129 commit 2a495ac

File tree

3 files changed

+53
-44
lines changed

3 files changed

+53
-44
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class A(ABC):
2424
def val(self) -> str:
2525
...
2626

27-
@dataclass_abc
27+
@dataclass_abc(frozen=True)
2828
class B(A):
2929
val: str # overwrites the abstract property 'val' in 'A'
3030
```
@@ -50,7 +50,7 @@ class A(ABC):
5050
def val1(self) -> str:
5151
...
5252

53-
@dataclass_abc
53+
@dataclass_abc(frozen=True)
5454
class B(A):
5555
val1: str
5656
val2: str

dataclass_abc/__init__.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def set_func(self, val, key=key):
4848
return new_cls
4949

5050

51-
def dataclass_abc(cls, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False):
51+
def dataclass_abc(_cls=None, *, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False):
5252
"""
5353
meant to be used as a class decorator similarly to `dataclasses.dataclass_abc`.
5454
@@ -59,43 +59,52 @@ def dataclass_abc(cls, repr=True, eq=True, order=False, unsafe_hash=False, froze
5959
6060
"""
6161

62-
if cls.__module__ in sys.modules:
63-
globals = sys.modules[cls.__module__].__dict__
64-
else:
65-
# Theoretically this can happen if someone writes
66-
# a custom string to cls.__module__. In which case
67-
# such dataclass_abc won't be fully introspectable
68-
# (w.r.t. typing.get_type_hints) but will still function
69-
# correctly.
70-
globals = {}
71-
72-
cls = parent_dataclass(cls, init=False, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen)
73-
74-
fields = cls.__dict__['__dataclass_fields__']
75-
76-
def gen_fields():
77-
for field in fields.values():
78-
# Include InitVars and regular fields (so, not ClassVars).
79-
if field._field_type in (_FIELD, _FIELD_INITVAR):
80-
field.default = MISSING
81-
field.default_factory = MISSING
82-
yield field
83-
84-
flds = list(gen_fields())
85-
86-
# Does this class have a post-init function?
87-
has_post_init = hasattr(cls, _POST_INIT_NAME)
88-
89-
_set_new_attribute(cls, '__init__',
90-
_init_fn(flds,
91-
frozen,
92-
has_post_init,
93-
# The name to use for the "self"
94-
# param in __init__. Use "self"
95-
# if possible.
96-
'__dataclass_self__' if 'self' in fields
97-
else 'self',
98-
globals,
99-
))
100-
101-
return resolve_abc_prop(cls)
62+
def wrap(cls):
63+
if cls.__module__ in sys.modules:
64+
globals = sys.modules[cls.__module__].__dict__
65+
else:
66+
# Theoretically this can happen if someone writes
67+
# a custom string to cls.__module__. In which case
68+
# such dataclass_abc won't be fully introspectable
69+
# (w.r.t. typing.get_type_hints) but will still function
70+
# correctly.
71+
globals = {}
72+
73+
cls = parent_dataclass(cls, init=False, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen)
74+
75+
fields = cls.__dict__['__dataclass_fields__']
76+
77+
def gen_fields():
78+
for field in fields.values():
79+
# Include InitVars and regular fields (so, not ClassVars).
80+
if field._field_type in (_FIELD, _FIELD_INITVAR):
81+
field.default = MISSING
82+
field.default_factory = MISSING
83+
yield field
84+
85+
flds = list(gen_fields())
86+
87+
# Does this class have a post-init function?
88+
has_post_init = hasattr(cls, _POST_INIT_NAME)
89+
90+
_set_new_attribute(cls, '__init__',
91+
_init_fn(flds,
92+
frozen,
93+
has_post_init,
94+
# The name to use for the "self"
95+
# param in __init__. Use "self"
96+
# if possible.
97+
'__dataclass_self__' if 'self' in fields
98+
else 'self',
99+
globals,
100+
))
101+
102+
return resolve_abc_prop(cls)
103+
104+
# See if we're being called as @dataclass or @dataclass().
105+
if _cls is None:
106+
# We're called with parens.
107+
return wrap
108+
109+
# We're called as @dataclass without parens.
110+
return wrap(_cls)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(
1212
name='dataclass_abc',
13-
version='0.0.3',
13+
version='0.0.4',
1414
description='Library that lets you define abstract properties for dataclasses.',
1515
long_description=long_description,
1616
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)