diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c7d265a..424ed166 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Fix a bug where seqs were not considered valid input for matching clauses of the `case` macro (#1148) * Fix a bug where `py->lisp` did not keywordize string keys potentially containing namespaces (#1156) * Fix a bug where anonymous functions using the `#(...)` reader syntax were not properly expanded in a syntax quote (#1160) + * Fix a bug where certain types of objects (such as objects created via `deftype`) could not be unquoted correctly in macros (#1153) ## [v0.3.3] ### Added diff --git a/src/basilisp/lang/compiler/generator.py b/src/basilisp/lang/compiler/generator.py index ff3e235d..f775fe72 100644 --- a/src/basilisp/lang/compiler/generator.py +++ b/src/basilisp/lang/compiler/generator.py @@ -7,6 +7,7 @@ import functools import hashlib import logging +import pickle # nosec B403 import re import uuid from collections.abc import Collection, Iterable, Mapping, MutableMapping @@ -95,7 +96,7 @@ ast_ClassDef, ast_FunctionDef, ) -from basilisp.lang.interfaces import IMeta, IRecord, ISeq, ISeqable, IType +from basilisp.lang.interfaces import IMeta, ISeq from basilisp.lang.runtime import CORE_NS from basilisp.lang.runtime import NS_VAR_NAME as LISP_NS_VAR from basilisp.lang.runtime import BasilispModule, Var @@ -764,6 +765,7 @@ def _var_ns_as_python_sym(name: str) -> str: _ATTR_CLASS_DECORATOR_NAME = _load_attr(f"{_ATTR_ALIAS}.define") _ATTR_FROZEN_DECORATOR_NAME = _load_attr(f"{_ATTR_ALIAS}.frozen") _ATTRIB_FIELD_FN_NAME = _load_attr(f"{_ATTR_ALIAS}.field") +_BASILISP_LOAD_CONSTANT_NAME = _load_attr(f"{_RUNTIME_ALIAS}._load_constant") _COERCE_SEQ_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}.to_seq") _BASILISP_FN_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._basilisp_fn") _FN_WITH_ATTRS_FN_NAME = _load_attr(f"{_RUNTIME_ALIAS}._with_attrs") @@ -3559,9 +3561,24 @@ def 
_const_val_to_py_ast( structures need to call into this function to generate Python AST nodes for nested elements. For top-level :const Lisp AST nodes, see `_const_node_to_py_ast`.""" - raise ctx.GeneratorException( - f"No constant handler is defined for type {type(form)}" - ) + try: + serialized = pickle.dumps(form) + except (pickle.PicklingError, AttributeError, TypeError, RecursionError) as e: + # For types without custom "constant" handling code, we defer to pickle + # to generate a representation that can be reloaded from the generated + # byte code. Unpicklable objects surface as PicklingError, AttributeError, or + # TypeError depending on the object, in which case we'll fail here. + raise ctx.GeneratorException( + f"Unable to emit bytecode for generating a constant {type(form)}" + ) from e + else: + return GeneratedPyAST( + node=ast.Call( + func=_BASILISP_LOAD_CONSTANT_NAME, + args=[ast.Constant(value=serialized)], + keywords=[], + ), + ) def _collection_literal_to_py_ast( @@ -3777,54 +3794,6 @@ def _const_set_to_py_ast( ) -@_const_val_to_py_ast.register(IRecord) -def _const_record_to_py_ast( - form: IRecord, ctx: GeneratorContext -) -> GeneratedPyAST[ast.expr]: - assert isinstance(form, IRecord) and isinstance( - form, ISeqable - ), "IRecord types should also be ISeq" - - tp = type(form) - assert hasattr(tp, "create") and callable( - tp.create - ), "IRecord and IType must declare a .create class method" - - form_seq = runtime.to_seq(form) - assert form_seq is not None, "IRecord types must be iterable" - - # pylint: disable=no-member - keys: list[Optional[ast.expr]] = [] - vals: list[ast.expr] = [] - vals_deps: list[PyASTNode] = [] - for k, v in form_seq: - assert isinstance(k, kw.Keyword), "Record key in seq must be keyword" - key_nodes = _kw_to_py_ast(k, ctx) - keys.append(key_nodes.node) - assert ( - not key_nodes.dependencies - ), "Simple AST generators must emit no dependencies" - - val_nodes = _const_val_to_py_ast(v, ctx) - vals.append(val_nodes.node) - vals_deps.extend(val_nodes.dependencies) - - 
return GeneratedPyAST( - node=ast.Call( - func=_load_attr(f"{tp.__qualname__}.create"), - args=[ - ast.Call( - func=_NEW_MAP_FN_NAME, - args=[ast.Dict(keys=keys, values=vals)], - keywords=[], - ) - ], - keywords=[], - ), - dependencies=vals_deps, - ) - - @_const_val_to_py_ast.register(llist.PersistentList) @_const_val_to_py_ast.register(ISeq) def _const_seq_to_py_ast( @@ -3849,25 +3818,6 @@ def _const_seq_to_py_ast( ) -@_const_val_to_py_ast.register(IType) -def _const_type_to_py_ast( - form: IType, ctx: GeneratorContext -) -> GeneratedPyAST[ast.expr]: - tp = type(form) - - ctor_args = [] - ctor_arg_deps: list[PyASTNode] = [] - for field in attr.fields(tp): # type: ignore[arg-type, misc, unused-ignore] - field_nodes = _const_val_to_py_ast(getattr(form, field.name, None), ctx) - ctor_args.append(field_nodes.node) - ctor_args.extend(field_nodes.dependencies) # type: ignore[arg-type] - - return GeneratedPyAST( - node=ast.Call(func=_load_attr(tp.__qualname__), args=ctor_args, keywords=[]), - dependencies=ctor_arg_deps, - ) - - @_const_val_to_py_ast.register(vec.PersistentVector) def _const_vec_to_py_ast( form: vec.PersistentVector, ctx: GeneratorContext diff --git a/src/basilisp/lang/runtime.py b/src/basilisp/lang/runtime.py index c98cd4de..dc1547ed 100644 --- a/src/basilisp/lang/runtime.py +++ b/src/basilisp/lang/runtime.py @@ -11,6 +11,7 @@ import logging import math import numbers +import pickle # nosec B403 import platform import re import sys @@ -2170,6 +2171,18 @@ def wrap_class(cls: type): return wrap_class +def _load_constant(s: bytes) -> Any: + """Load a compiler "constant" from a byte string produced by Python's `pickle` + module. 
+ + Constant types without special handling are emitted to bytecode as a byte string + produced by `pickle.dumps`. Raises `RuntimeException` if unpickling fails.""" + try: + return pickle.loads(s) # nosec B301 + except (pickle.UnpicklingError, AttributeError, EOFError, ImportError) as e: + raise RuntimeException("Unable to load constant value") from e + + ############################### # Symbol and Alias Resolution # ############################### diff --git a/tests/basilisp/compiler_test.py b/tests/basilisp/compiler_test.py index dc3018af..e409300f 100644 --- a/tests/basilisp/compiler_test.py +++ b/tests/basilisp/compiler_test.py @@ -30,7 +30,7 @@ from basilisp.lang import vector as vec from basilisp.lang.compiler.constants import SYM_INLINE_META_KW, SYM_PRIVATE_META_KEY from basilisp.lang.exception import format_exception -from basilisp.lang.interfaces import IType, IWithMeta +from basilisp.lang.interfaces import IRecord, IType, IWithMeta from basilisp.lang.runtime import Var from basilisp.lang.util import demunge from tests.basilisp.helpers import CompileFn, get_or_create_ns @@ -3226,10 +3226,6 @@ def test_function_result_is_same_with_or_without_auto_inline( assert val == 7 -def test_macro_expansion(lcompile: CompileFn): - assert llist.l(1, 2, 3) == lcompile("((fn [] '(1 2 3)))") - - class TestMacroexpandFunctions: @pytest.fixture def example_macro(self, lcompile: CompileFn): @@ -4220,6 +4216,162 @@ def test_loop_with_recur(self, lcompile: CompileFn): assert "tester" == lcompile(code) +class TestMacros: + def test_macro_expansion(self, lcompile: CompileFn): + assert llist.l(1, 2, 3) == lcompile("((fn [] '(1 2 3)))") + + def test_syntax_quoting( + self, test_ns: str, lcompile: CompileFn, resolver: reader.Resolver + ): + code = """ + (def some-val \"some value!\") + + `(some-val)""" + assert llist.l(sym.symbol("some-val", ns=test_ns)) == lcompile( + code, resolver=resolver + ) + + code = """ + (def second-val \"some value!\") + + `(other-val)""" + assert llist.l(sym.symbol("other-val")) == lcompile(code) + + code = """ + (def a-str \"a definite 
string\") + (def a-number 1583) + + `(a-str ~a-number)""" + assert llist.l(sym.symbol("a-str", ns=test_ns), 1583) == lcompile( + code, resolver=resolver + ) + + code = """ + (def whatever \"yes, whatever\") + (def ssss \"a snake\") + + `(whatever ~@[ssss 45])""" + assert llist.l(sym.symbol("whatever", ns=test_ns), "a snake", 45) == lcompile( + code, resolver=resolver + ) + + assert llist.l(sym.symbol("my-symbol", ns=test_ns)) == lcompile( + "`(my-symbol)", resolver=resolver + ) + + def test_syntax_quoting_anonymous_fns_with_single_arg( + self, test_ns: str, lcompile: CompileFn, resolver: reader.Resolver + ): + single_arg_fn = lcompile("`#(println %)", resolver=resolver) + assert single_arg_fn.first == sym.symbol("fn*") + single_arg_vec = runtime.nth(single_arg_fn, 1) + assert isinstance(single_arg_vec, vec.PersistentVector) + single_arg = single_arg_vec[0] + assert isinstance(single_arg, sym.Symbol) + assert re.match(r"arg-1_\d+", single_arg.name) is not None + println_call = runtime.nth(single_arg_fn, 2) + assert runtime.nth(println_call, 1) == single_arg + + def test_syntax_quoting_anonymous_fns_with_multiple_args( + self, test_ns: str, lcompile: CompileFn, resolver: reader.Resolver + ): + multi_arg_fn = lcompile("`#(vector %1 %2 %3)", resolver=resolver) + assert multi_arg_fn.first == sym.symbol("fn*") + multi_arg_vec = runtime.nth(multi_arg_fn, 1) + assert isinstance(multi_arg_vec, vec.PersistentVector) + + for arg in multi_arg_vec: + assert isinstance(arg, sym.Symbol) + assert re.match(r"arg-\d_\d+", arg.name) is not None + + vector_call = runtime.nth(multi_arg_fn, 2) + assert vec.vector(runtime.nthrest(vector_call, 1)) == multi_arg_vec + + def test_syntax_quoting_anonymous_fns_with_rest_arg( + self, test_ns: str, lcompile: CompileFn, resolver: reader.Resolver + ): + rest_arg_fn = lcompile("`#(vec %&)", resolver=resolver) + assert rest_arg_fn.first == sym.symbol("fn*") + rest_arg_vec = runtime.nth(rest_arg_fn, 1) + assert isinstance(rest_arg_vec, 
vec.PersistentVector) + assert rest_arg_vec[0] == sym.symbol("&") + rest_arg = rest_arg_vec[1] + assert isinstance(rest_arg, sym.Symbol) + assert re.match(r"arg-rest_\d+", rest_arg.name) is not None + vec_call = runtime.nth(rest_arg_fn, 2) + assert runtime.nth(vec_call, 1) == rest_arg + + @pytest.mark.parametrize("code,v", [("`(s)", llist.l(sym.symbol("s")))]) + def test_unquote(self, lcompile: CompileFn, code: str, v): + assert v == lcompile(code) + + def test_unquote_arbitrary_deftypes(self, lcompile: CompileFn): + v = lcompile( + """ + (deftype Point [a b c]) + + (defmacro make-point + [a b c] + `(identity ~(Point. a b c))) + + (make-point 1 2 3) + """ + ) + + assert isinstance(v, IType) + assert (1, 2, 3) == (v.a, v.b, v.c) + + def test_unquote_arbitrary_defrecords(self, lcompile: CompileFn): + v = lcompile( + """ + (defrecord Point [a b c]) + + (defmacro make-point + [a b c] + `(identity ~(Point. a b c))) + + (make-point 1 2 3) + """ + ) + + assert isinstance(v, IRecord) + assert (1, 2, 3) == (v.a, v.b, v.c) + + @pytest.mark.parametrize("code", ["~s", "`(~s)"]) + def test_invalid_unquote(self, lcompile: CompileFn, code: str): + with pytest.raises(compiler.CompilerException): + lcompile(code) + + @pytest.mark.parametrize( + "code,v", + [ + ("`(~@[1 2 3])", llist.l(1, 2, 3)), + ("'(~@53233)", llist.l(llist.l(reader._UNQUOTE_SPLICING, 53233))), + ], + ) + def test_unquote_splicing(self, lcompile: CompileFn, code: str, v): + assert v == lcompile(code) + + @pytest.mark.parametrize( + "code,v", + [ + ( + "`(print ~@[1 2 3])", + llist.l(sym.symbol("print", ns="basilisp.core"), 1, 2, 3), + ) + ], + ) + def test_unquote_splicing_with_resolver( + self, lcompile: CompileFn, resolver: reader.Resolver, code: str, v + ): + assert v == lcompile(code, resolver=resolver) + + @pytest.mark.parametrize("code", ["~@[1 2 3]"]) + def test_invalid_unquote_splicing(self, lcompile: CompileFn, code: str): + with pytest.raises(TypeError): + lcompile(code) + + class TestQuote: 
@pytest.mark.parametrize("code", ["(quote)", "(quote form other-form)"]) def test_quote_num_elems(self, lcompile: CompileFn, code: str): @@ -5832,89 +5984,6 @@ def test_set_can_object_attrs(self, lcompile: CompileFn, ns: runtime.Namespace): assert "now a string" == var.value.some_field -def test_syntax_quoting(test_ns: str, lcompile: CompileFn, resolver: reader.Resolver): - code = """ - (def some-val \"some value!\") - - `(some-val)""" - assert llist.l(sym.symbol("some-val", ns=test_ns)) == lcompile( - code, resolver=resolver - ) - - code = """ - (def second-val \"some value!\") - - `(other-val)""" - assert llist.l(sym.symbol("other-val")) == lcompile(code) - - code = """ - (def a-str \"a definite string\") - (def a-number 1583) - - `(a-str ~a-number)""" - assert llist.l(sym.symbol("a-str", ns=test_ns), 1583) == lcompile( - code, resolver=resolver - ) - - code = """ - (def whatever \"yes, whatever\") - (def ssss \"a snake\") - - `(whatever ~@[ssss 45])""" - assert llist.l(sym.symbol("whatever", ns=test_ns), "a snake", 45) == lcompile( - code, resolver=resolver - ) - - assert llist.l(sym.symbol("my-symbol", ns=test_ns)) == lcompile( - "`(my-symbol)", resolver=resolver - ) - - -def test_syntax_quoting_anonymous_fns_with_single_arg( - test_ns: str, lcompile: CompileFn, resolver: reader.Resolver -): - single_arg_fn = lcompile("`#(println %)", resolver=resolver) - assert single_arg_fn.first == sym.symbol("fn*") - single_arg_vec = runtime.nth(single_arg_fn, 1) - assert isinstance(single_arg_vec, vec.PersistentVector) - single_arg = single_arg_vec[0] - assert isinstance(single_arg, sym.Symbol) - assert re.match(r"arg-1_\d+", single_arg.name) is not None - println_call = runtime.nth(single_arg_fn, 2) - assert runtime.nth(println_call, 1) == single_arg - - -def test_syntax_quoting_anonymous_fns_with_multiple_args( - test_ns: str, lcompile: CompileFn, resolver: reader.Resolver -): - multi_arg_fn = lcompile("`#(vector %1 %2 %3)", resolver=resolver) - assert multi_arg_fn.first 
== sym.symbol("fn*") - multi_arg_vec = runtime.nth(multi_arg_fn, 1) - assert isinstance(multi_arg_vec, vec.PersistentVector) - - for arg in multi_arg_vec: - assert isinstance(arg, sym.Symbol) - assert re.match(r"arg-\d_\d+", arg.name) is not None - - vector_call = runtime.nth(multi_arg_fn, 2) - assert vec.vector(runtime.nthrest(vector_call, 1)) == multi_arg_vec - - -def test_syntax_quoting_anonymous_fns_with_rest_arg( - test_ns: str, lcompile: CompileFn, resolver: reader.Resolver -): - rest_arg_fn = lcompile("`#(vec %&)", resolver=resolver) - assert rest_arg_fn.first == sym.symbol("fn*") - rest_arg_vec = runtime.nth(rest_arg_fn, 1) - assert isinstance(rest_arg_vec, vec.PersistentVector) - assert rest_arg_vec[0] == sym.symbol("&") - rest_arg = rest_arg_vec[1] - assert isinstance(rest_arg, sym.Symbol) - assert re.match(r"arg-rest_\d+", rest_arg.name) is not None - vec_call = runtime.nth(rest_arg_fn, 2) - assert runtime.nth(vec_call, 1) == rest_arg - - class TestThrow: def test_throw_not_enough_args(self, lcompile: CompileFn): with pytest.raises(compiler.CompilerException): @@ -6138,29 +6207,6 @@ def test_try_may_not_have_multiple_finallys(self, lcompile: CompileFn): ) -def test_unquote(lcompile: CompileFn): - with pytest.raises(compiler.CompilerException): - lcompile("~s") - - assert llist.l(sym.symbol("s")) == lcompile("`(s)") - - with pytest.raises(compiler.CompilerException): - lcompile("`(~s)") - - -def test_unquote_splicing(lcompile: CompileFn, resolver: reader.Resolver): - with pytest.raises(TypeError): - lcompile("~@[1 2 3]") - - assert llist.l(1, 2, 3) == lcompile("`(~@[1 2 3])") - - assert llist.l(sym.symbol("print", ns="basilisp.core"), 1, 2, 3) == lcompile( - "`(print ~@[1 2 3])", resolver=resolver - ) - - assert llist.l(llist.l(reader._UNQUOTE_SPLICING, 53233)) == lcompile("'(~@53233)") - - class TestSymbolResolution: def test_bare_sym_resolves_builtins(self, lcompile: CompileFn): assert object is lcompile("object") @@ -6736,6 +6782,52 @@ def 
test_cross_ns_macro_symbol_resolution_with_refers( runtime.Namespace.remove(other_ns_name) runtime.Namespace.remove(third_ns_name) + def test_cross_ns_macro_deftype_symbol_resolution( + self, + lcompile: CompileFn, + tmp_path: pathlib.Path, + monkeypatch, + ): + """Unquoted instances of a `deftype` from another namespace keep their type.""" + cross_ns_deftype_test_dir = tmp_path / "cross_ns_deftype_test" + cross_ns_deftype_test_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.syspath_prepend(tmp_path) + + other_ns_file = cross_ns_deftype_test_dir / "other.lpy" + other_ns_file.write_text( + """ + (ns cross-ns-deftype-test.other) + + (deftype TypeOther []) + """ + ) + + main_ns_file = cross_ns_deftype_test_dir / "issue.lpy" + main_ns_file.write_text( + """ + (ns cross-ns-deftype-test.issue + (:require cross-ns-deftype-test.other)) + + (defmacro issue [] + `(identity ~(cross-ns-deftype-test.other/TypeOther))) + + (def result [(issue) (cross-ns-deftype-test.other/TypeOther)]) + """ + ) + + result = lcompile( + """ + (require 'cross-ns-deftype-test.issue) + + cross-ns-deftype-test.issue/result + """ + ) + + assert isinstance(result[0], type(result[1])) + + sys.modules.pop("cross_ns_deftype_test.issue") + sys.modules.pop("cross_ns_deftype_test.other") + class TestWarnOnArityMismatch: def test_warning_on_arity_mismatch( diff --git a/tests/basilisp/conftest.py b/tests/basilisp/conftest.py index 66bd1c42..b1bc6eca 100644 --- a/tests/basilisp/conftest.py +++ b/tests/basilisp/conftest.py @@ -36,9 +36,11 @@ def test_ns_sym(test_ns: str) -> sym.Symbol: def ns(test_ns: str, test_ns_sym: sym.Symbol) -> runtime.Namespace: get_or_create_ns(test_ns_sym) with runtime.ns_bindings(test_ns) as ns: + sys.modules[ns.module.__name__] = ns.module try: yield ns finally: + del sys.modules[ns.module.__name__] runtime.Namespace.remove(test_ns_sym)