From 265065579dc51307f404b1bd33b4fe5384ede283 Mon Sep 17 00:00:00 2001 From: Chris Rink Date: Wed, 4 Dec 2024 11:25:49 -0500 Subject: [PATCH 1/2] Fix a bug with syntax quoted anonymous functions --- CHANGELOG.md | 1 + src/basilisp/lang/reader.py | 17 ++-- tests/basilisp/compiler_test.py | 45 ++++++++++ tests/basilisp/reader_test.py | 147 ++++++++++++++++++++++---------- 4 files changed, 158 insertions(+), 52 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96e44469..bf7c39ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Fix a bug where `#` characters were not legal in keywords and symbols (#1149) * Fix a bug where seqs were not considered valid input for matching clauses of the `case` macro (#1148) * Fix a bug where `py->lisp` did not keywordize string keys potentially containing namespaces (#1156) + * Fix a bug where anonymous functions using the `#(...)` reader syntax were not properly expanded in a syntax quote (#1160) ## [v0.3.3] ### Added diff --git a/src/basilisp/lang/reader.py b/src/basilisp/lang/reader.py index c2961198..ea9624d4 100644 --- a/src/basilisp/lang/reader.py +++ b/src/basilisp/lang/reader.py @@ -1183,11 +1183,13 @@ def _read_function(ctx: ReaderContext) -> llist.PersistentList: if ctx.is_in_anon_fn: raise ctx.syntax_error("Nested #() definitions not allowed") + current_ns = get_current_ns() + with ctx.in_anon_fn(): form = _read_list(ctx) arg_set = set() - def arg_suffix(arg_num): + def arg_suffix(arg_num: Optional[str]) -> str: if arg_num is None: return "1" elif arg_num == "&": @@ -1195,13 +1197,15 @@ def arg_suffix(arg_num): else: return arg_num - def sym_replacement(arg_num): + def sym_replacement(arg_num: Optional[str]) -> sym.Symbol: suffix = arg_suffix(arg_num) + if ctx.is_syntax_quoted: + suffix = f"{suffix}#" return sym.symbol(f"arg-{suffix}") def identify_and_replace(f): if isinstance(f, sym.Symbol): - if f.ns is None: + if f.ns is None or f.ns == current_ns.name: match = fn_macro_args.match(f.name) if match is not None: arg_num = match.group(2) @@ -1217,9 +1221,10 @@ def identify_and_replace(f): if len(numbered_args) > 0: max_arg = max(numbered_args) arg_list = [sym_replacement(str(i)) for i in range(1, max_arg + 1)] - if "rest" in arg_set: - arg_list.append(_AMPERSAND) - arg_list.append(sym_replacement("rest")) + + if "rest" in arg_set: + arg_list.append(_AMPERSAND) + arg_list.append(sym_replacement("rest")) return llist.l(_FN, vec.vector(arg_list), body) diff --git a/tests/basilisp/compiler_test.py b/tests/basilisp/compiler_test.py index 751cf8fc..dc3018af 100644 --- a/tests/basilisp/compiler_test.py +++ b/tests/basilisp/compiler_test.py @@ -5870,6 +5870,51 @@ def test_syntax_quoting(test_ns: str, lcompile: CompileFn, resolver: reader.Reso ) +def test_syntax_quoting_anonymous_fns_with_single_arg( + test_ns: str, lcompile: CompileFn, resolver: reader.Resolver +): + single_arg_fn = lcompile("`#(println %)", resolver=resolver) + assert single_arg_fn.first == sym.symbol("fn*") + single_arg_vec = runtime.nth(single_arg_fn, 1) + assert isinstance(single_arg_vec, vec.PersistentVector) + single_arg = single_arg_vec[0] + assert isinstance(single_arg, sym.Symbol) + assert re.match(r"arg-1_\d+", single_arg.name) is not None + println_call = runtime.nth(single_arg_fn, 2) + assert runtime.nth(println_call, 1) == single_arg + + +def test_syntax_quoting_anonymous_fns_with_multiple_args( + test_ns: str, lcompile: CompileFn, resolver: reader.Resolver +): + multi_arg_fn = lcompile("`#(vector %1 %2 %3)", resolver=resolver) + assert multi_arg_fn.first == sym.symbol("fn*") + multi_arg_vec = runtime.nth(multi_arg_fn, 1) + assert isinstance(multi_arg_vec, vec.PersistentVector) + + for arg in multi_arg_vec: + assert isinstance(arg, sym.Symbol) + assert re.match(r"arg-\d_\d+", arg.name) is not None + + vector_call = runtime.nth(multi_arg_fn, 2) + assert vec.vector(runtime.nthrest(vector_call, 1)) == multi_arg_vec + + +def test_syntax_quoting_anonymous_fns_with_rest_arg( + test_ns: str, lcompile: CompileFn, resolver: reader.Resolver +): + rest_arg_fn = lcompile("`#(vec %&)", resolver=resolver) + assert rest_arg_fn.first == sym.symbol("fn*") + rest_arg_vec = runtime.nth(rest_arg_fn, 1) + assert isinstance(rest_arg_vec, vec.PersistentVector) + assert rest_arg_vec[0] == sym.symbol("&") + rest_arg = rest_arg_vec[1] + assert isinstance(rest_arg, sym.Symbol) + assert re.match(r"arg-rest_\d+", rest_arg.name) is not None + vec_call = runtime.nth(rest_arg_fn, 2) + assert runtime.nth(vec_call, 1) == rest_arg + + class TestThrow: def test_throw_not_enough_args(self, lcompile: CompileFn): with pytest.raises(compiler.CompilerException): diff --git a/tests/basilisp/reader_test.py b/tests/basilisp/reader_test.py index abe1c80d..a089d2a8 100644 --- a/tests/basilisp/reader_test.py +++ b/tests/basilisp/reader_test.py @@ -1793,54 +1793,109 @@ def test_splicing_form_in_maps(self): ) -def test_function_reader_macro(): - assert read_str_first("#()") == llist.l(sym.symbol("fn*"), vec.v(), None) - assert read_str_first("#(identity %)") == llist.l( - sym.symbol("fn*"), - vec.v(sym.symbol("arg-1")), - llist.l(sym.symbol("identity"), sym.symbol("arg-1")), - ) - assert read_str_first("#(identity %1)") == llist.l( - sym.symbol("fn*"), - vec.v(sym.symbol("arg-1")), - llist.l(sym.symbol("identity"), sym.symbol("arg-1")), - ) - assert read_str_first("#(identity %& %1)") == llist.l( - sym.symbol("fn*"), - vec.v(sym.symbol("arg-1"), sym.symbol("&"), sym.symbol("arg-rest")), - llist.l(sym.symbol("identity"), sym.symbol("arg-rest"), sym.symbol("arg-1")), - ) - assert read_str_first("#(identity %3)") == llist.l( - sym.symbol("fn*"), - vec.v(sym.symbol("arg-1"), sym.symbol("arg-2"), sym.symbol("arg-3")), - llist.l(sym.symbol("identity"), sym.symbol("arg-3")), - ) - assert read_str_first("#(identity %3 %&)") == llist.l( - sym.symbol("fn*"), - vec.v( - sym.symbol("arg-1"), - sym.symbol("arg-2"), - sym.symbol("arg-3"), - sym.symbol("&"), - sym.symbol("arg-rest"), - ), - llist.l(sym.symbol("identity"), sym.symbol("arg-3"), sym.symbol("arg-rest")), - ) - assert read_str_first("#(identity {:arg %})") == llist.l( - sym.symbol("fn*"), - vec.v( - sym.symbol("arg-1"), - ), - llist.l( - sym.symbol("identity"), lmap.map({kw.keyword("arg"): sym.symbol("arg-1")}) - ), +class TestFunctionReaderMacro: + @pytest.mark.parametrize( + "code,v", + [ + ("#()", llist.l(sym.symbol("fn*"), vec.v(), None)), + ( + "#(identity %)", + llist.l( + sym.symbol("fn*"), + vec.v(sym.symbol("arg-1")), + llist.l(sym.symbol("identity"), sym.symbol("arg-1")), + ), + ), + ( + "#(identity %1)", + llist.l( + sym.symbol("fn*"), + vec.v(sym.symbol("arg-1")), + llist.l(sym.symbol("identity"), sym.symbol("arg-1")), + ), + ), + ( + "#(identity %& %1)", + llist.l( + sym.symbol("fn*"), + vec.v(sym.symbol("arg-1"), sym.symbol("&"), sym.symbol("arg-rest")), + llist.l( + sym.symbol("identity"), + sym.symbol("arg-rest"), + sym.symbol("arg-1"), + ), + ), + ), + ( + "#(identity %3)", + llist.l( + sym.symbol("fn*"), + vec.v( + sym.symbol("arg-1"), sym.symbol("arg-2"), sym.symbol("arg-3") + ), + llist.l(sym.symbol("identity"), sym.symbol("arg-3")), + ), + ), + ( + "#(identity %3 %&)", + llist.l( + sym.symbol("fn*"), + vec.v( + sym.symbol("arg-1"), + sym.symbol("arg-2"), + sym.symbol("arg-3"), + sym.symbol("&"), + sym.symbol("arg-rest"), + ), + llist.l( + sym.symbol("identity"), + sym.symbol("arg-3"), + sym.symbol("arg-rest"), + ), + ), + ), + ( + "#(identity {:arg %})", + llist.l( + sym.symbol("fn*"), + vec.v( + sym.symbol("arg-1"), + ), + llist.l( + sym.symbol("identity"), + lmap.map({kw.keyword("arg"): sym.symbol("arg-1")}), + ), + ), + ), + ( + "#(vec %&)", + llist.l( + sym.symbol("fn*"), + vec.v(sym.symbol("&"), sym.symbol("arg-rest")), + llist.l(sym.symbol("vec"), sym.symbol("arg-rest")), + ), + ), + ( + "#(vector %1 %&)", + llist.l( + sym.symbol("fn*"), + vec.v(sym.symbol("arg-1"), sym.symbol("&"), sym.symbol("arg-rest")), + llist.l( + sym.symbol("vector"), + sym.symbol("arg-1"), + sym.symbol("arg-rest"), + ), + ), + ), + ], ) + def test_function_reader_macro(self, code: str, v): + assert v == read_str_first(code) - with pytest.raises(reader.SyntaxError): - read_str_first("#(identity #(%1 %2))") - - with pytest.raises(reader.SyntaxError): - read_str_first("#app/ermagrd [1 2 3]") + @pytest.mark.parametrize("code", ["#(identity #(%1 %2))", "#app/ermagrd [1 2 3]"]) + def test_invalid_function_reader_macro(self, code: str): + with pytest.raises(reader.SyntaxError): + read_str_first(code) def test_deref(): From 963550ba1f9674c8036eb6cf38eab8ac9cb202ff Mon Sep 17 00:00:00 2001 From: Chris Rink Date: Wed, 4 Dec 2024 11:29:05 -0500 Subject: [PATCH 2/2] Comment --- src/basilisp/lang/reader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/basilisp/lang/reader.py b/src/basilisp/lang/reader.py index ea9624d4..1f50fb0f 100644 --- a/src/basilisp/lang/reader.py +++ b/src/basilisp/lang/reader.py @@ -1205,6 +1205,10 @@ def sym_replacement(arg_num: Optional[str]) -> sym.Symbol: def identify_and_replace(f): if isinstance(f, sym.Symbol): + # Checking against the current namespace is generally only used for + # when anonymous function definitions are syntax quoted. Arguments + # are resolved in terms of the current namespace, so we simply check + # if the symbol namespace matches the current runtime namespace. if f.ns is None or f.ns == current_ns.name: match = fn_macro_args.match(f.name) if match is not None: