Skip to content

Commit de49e91

Browse files
authored
Allow default values for deftype* fields (#380)
* Allow default values for deftype* fields * I hate linting
1 parent 762d3e6 commit de49e91

File tree

4 files changed

+66
-6
lines changed

4 files changed

+66
-6
lines changed

src/basilisp/lang/compiler/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class SpecialForm:
135135

136136
SYM_ASYNC_META_KEY = kw.keyword("async")
137137
SYM_CLASSMETHOD_META_KEY = kw.keyword("classmethod")
138+
SYM_DEFAULT_META_KEY = kw.keyword("default")
138139
SYM_DYNAMIC_META_KEY = kw.keyword("dynamic")
139140
SYM_PROPERTY_META_KEY = kw.keyword("property")
140141
SYM_MACRO_META_KEY = kw.keyword("macro")

src/basilisp/lang/compiler/generator.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -893,17 +893,29 @@ def _deftype_to_py_ast( # pylint: disable=too-many-branches
893893

894894
with ctx.new_symbol_table(node.name):
895895
type_nodes = []
896+
type_deps: List[ast.AST] = []
896897
for field in node.fields:
897898
safe_field = munge(field.name)
899+
900+
if field.init is not None:
901+
default_nodes = gen_py_ast(ctx, field.init)
902+
type_deps.extend(default_nodes.dependencies)
903+
attr_default_kws = [
904+
ast.keyword(arg="default", value=default_nodes.node)
905+
]
906+
else:
907+
attr_default_kws = []
908+
898909
type_nodes.append(
899910
ast.Assign(
900911
targets=[ast.Name(id=safe_field, ctx=ast.Store())],
901-
value=ast.Call(func=_ATTRIB_FIELD_FN_NAME, args=[], keywords=[]),
912+
value=ast.Call(
913+
func=_ATTRIB_FIELD_FN_NAME, args=[], keywords=attr_default_kws
914+
),
902915
)
903916
)
904917
ctx.symbol_table.new_symbol(sym.symbol(field.name), safe_field, field.local)
905918

906-
type_deps: List[ast.AST] = []
907919
for member in node.members:
908920
type_ast = __deftype_member_to_py_ast(ctx, member)
909921
type_nodes.append(type_ast.node) # type: ignore

src/basilisp/lang/compiler/parser.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
OBJECT_DUNDER_METHODS,
5454
SYM_ASYNC_META_KEY,
5555
SYM_CLASSMETHOD_META_KEY,
56+
SYM_DEFAULT_META_KEY,
5657
SYM_DYNAMIC_META_KEY,
5758
SYM_MACRO_META_KEY,
5859
SYM_MUTABLE_META_KEY,
@@ -451,6 +452,7 @@ def has_meta_prop(o: Union[IMeta, Var]) -> bool:
451452

452453

453454
_is_async = _meta_getter(SYM_ASYNC_META_KEY)
455+
_is_mutable = _meta_getter(SYM_MUTABLE_META_KEY)
454456
_is_py_classmethod = _meta_getter(SYM_CLASSMETHOD_META_KEY)
455457
_is_py_property = _meta_getter(SYM_PROPERTY_META_KEY)
456458
_is_py_staticmethod = _meta_getter(SYM_STATICMETHOD_META_KEY)
@@ -1125,7 +1127,12 @@ def __assert_deftype_impls_are_abstract( # pylint: disable=too-many-branches,to
11251127
)
11261128

11271129

1128-
def _deftype_ast(ctx: ParserContext, form: ISeq) -> DefType:
1130+
__DEFTYPE_DEFAULT_SENTINEL = object()
1131+
1132+
1133+
def _deftype_ast( # pylint: disable=too-many-branches
1134+
ctx: ParserContext, form: ISeq
1135+
) -> DefType:
11291136
assert form.first == SpecialForm.DEFTYPE
11301137

11311138
nelems = count(form)
@@ -1154,18 +1161,33 @@ def _deftype_ast(ctx: ParserContext, form: ISeq) -> DefType:
11541161
f"deftype* fields must be vector, not {type(fields)}", form=fields
11551162
)
11561163

1164+
has_defaults = False
11571165
with ctx.new_symbol_table(name.name):
11581166
is_frozen = True
11591167
param_nodes = []
11601168
for field in fields:
11611169
if not isinstance(field, sym.Symbol):
11621170
raise ParserException(f"deftype* fields must be symbols", form=field)
11631171

1164-
is_mutable = (
1172+
field_default = (
11651173
Maybe(field.meta)
1166-
.map(lambda m: m.entry(SYM_MUTABLE_META_KEY)) # type: ignore
1167-
.or_else_get(False)
1174+
.map(
1175+
lambda m: m.entry( # type: ignore
1176+
SYM_DEFAULT_META_KEY, __DEFTYPE_DEFAULT_SENTINEL
1177+
)
1178+
)
1179+
.value
11681180
)
1181+
if not has_defaults and field_default is not __DEFTYPE_DEFAULT_SENTINEL:
1182+
has_defaults = True
1183+
elif has_defaults and field_default is __DEFTYPE_DEFAULT_SENTINEL:
1184+
raise ParserException(
1185+
"deftype* fields without defaults may not appear after fields "
1186+
"without defaults",
1187+
form=field,
1188+
)
1189+
1190+
is_mutable = _is_mutable(field)
11691191
if is_mutable:
11701192
is_frozen = False
11711193

@@ -1175,6 +1197,9 @@ def _deftype_ast(ctx: ParserContext, form: ISeq) -> DefType:
11751197
local=LocalType.FIELD,
11761198
is_assignable=is_mutable,
11771199
env=ctx.get_node_env(),
1200+
init=parse_ast(ctx, field_default)
1201+
if field_default is not __DEFTYPE_DEFAULT_SENTINEL
1202+
else None,
11781203
)
11791204
param_nodes.append(binding)
11801205
ctx.put_new_symbol(field, binding, warn_if_unused=False)

tests/basilisp/compiler_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,28 @@ def test_deftype_cannot_set_immutable_field(self, ns: runtime.Namespace):
618618
"""
619619
)
620620

621+
def test_deftype_allow_default_fields(self, ns: runtime.Namespace):
622+
Point = lcompile("(deftype* Point [x ^{:default 2} y ^{:default 3} z])")
623+
pt = Point(1)
624+
assert (1, 2, 3) == (pt.x, pt.y, pt.z)
625+
pt1 = Point(1, 4)
626+
assert (1, 4, 3) == (pt1.x, pt1.y, pt1.z)
627+
pt2 = Point(1, 4, 5)
628+
assert (1, 4, 5) == (pt2.x, pt2.y, pt2.z)
629+
630+
@pytest.mark.parametrize(
631+
"code",
632+
[
633+
"(deftype* Point [^{:default 1} x y z])",
634+
"(deftype* Point [x ^{:default 2} y z])",
635+
],
636+
)
637+
def test_deftype_disallow_non_default_fields_after_default(
638+
self, ns: runtime.Namespace, code: str
639+
):
640+
with pytest.raises(compiler.CompilerException):
641+
lcompile(code)
642+
621643
class TestDefTypeMember:
622644
@pytest.mark.parametrize(
623645
"code",

0 commit comments

Comments
 (0)