Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Fix a bug where `#` characters were not legal in keywords and symbols (#1149)
* Fix a bug where seqs were not considered valid input for matching clauses of the `case` macro (#1148)
* Fix a bug where `py->lisp` did not keywordize string keys potentially containing namespaces (#1156)
* Fix a bug where anonymous functions using the `#(...)` reader syntax were not properly expanded in a syntax quote (#1160)

## [v0.3.3]
### Added
Expand Down
17 changes: 11 additions & 6 deletions src/basilisp/lang/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,25 +1183,29 @@ def _read_function(ctx: ReaderContext) -> llist.PersistentList:
if ctx.is_in_anon_fn:
raise ctx.syntax_error("Nested #() definitions not allowed")

current_ns = get_current_ns()

with ctx.in_anon_fn():
form = _read_list(ctx)
arg_set = set()

def arg_suffix(arg_num):
def arg_suffix(arg_num: Optional[str]) -> str:
if arg_num is None:
return "1"
elif arg_num == "&":
return "rest"
else:
return arg_num

def sym_replacement(arg_num):
def sym_replacement(arg_num: Optional[str]) -> sym.Symbol:
suffix = arg_suffix(arg_num)
if ctx.is_syntax_quoted:
suffix = f"{suffix}#"
return sym.symbol(f"arg-{suffix}")

def identify_and_replace(f):
if isinstance(f, sym.Symbol):
if f.ns is None:
if f.ns is None or f.ns == current_ns.name:
match = fn_macro_args.match(f.name)
if match is not None:
arg_num = match.group(2)
Expand All @@ -1217,9 +1221,10 @@ def identify_and_replace(f):
if len(numbered_args) > 0:
max_arg = max(numbered_args)
arg_list = [sym_replacement(str(i)) for i in range(1, max_arg + 1)]
if "rest" in arg_set:
arg_list.append(_AMPERSAND)
arg_list.append(sym_replacement("rest"))

if "rest" in arg_set:
arg_list.append(_AMPERSAND)
arg_list.append(sym_replacement("rest"))
Comment on lines +1229 to +1231
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a separate bug where #(vec %&) forms didn't expand correctly with or without syntax quotes to capture the rest argument.


return llist.l(_FN, vec.vector(arg_list), body)

Expand Down
45 changes: 45 additions & 0 deletions tests/basilisp/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5870,6 +5870,51 @@ def test_syntax_quoting(test_ns: str, lcompile: CompileFn, resolver: reader.Reso
)


def test_syntax_quoting_anonymous_fns_with_single_arg(
test_ns: str, lcompile: CompileFn, resolver: reader.Resolver
):
single_arg_fn = lcompile("`#(println %)", resolver=resolver)
assert single_arg_fn.first == sym.symbol("fn*")
single_arg_vec = runtime.nth(single_arg_fn, 1)
assert isinstance(single_arg_vec, vec.PersistentVector)
single_arg = single_arg_vec[0]
assert isinstance(single_arg, sym.Symbol)
assert re.match(r"arg-1_\d+", single_arg.name) is not None
println_call = runtime.nth(single_arg_fn, 2)
assert runtime.nth(println_call, 1) == single_arg


def test_syntax_quoting_anonymous_fns_with_multiple_args(
test_ns: str, lcompile: CompileFn, resolver: reader.Resolver
):
multi_arg_fn = lcompile("`#(vector %1 %2 %3)", resolver=resolver)
assert multi_arg_fn.first == sym.symbol("fn*")
multi_arg_vec = runtime.nth(multi_arg_fn, 1)
assert isinstance(multi_arg_vec, vec.PersistentVector)

for arg in multi_arg_vec:
assert isinstance(arg, sym.Symbol)
assert re.match(r"arg-\d_\d+", arg.name) is not None

vector_call = runtime.nth(multi_arg_fn, 2)
assert vec.vector(runtime.nthrest(vector_call, 1)) == multi_arg_vec


def test_syntax_quoting_anonymous_fns_with_rest_arg(
test_ns: str, lcompile: CompileFn, resolver: reader.Resolver
):
rest_arg_fn = lcompile("`#(vec %&)", resolver=resolver)
assert rest_arg_fn.first == sym.symbol("fn*")
rest_arg_vec = runtime.nth(rest_arg_fn, 1)
assert isinstance(rest_arg_vec, vec.PersistentVector)
assert rest_arg_vec[0] == sym.symbol("&")
rest_arg = rest_arg_vec[1]
assert isinstance(rest_arg, sym.Symbol)
assert re.match(r"arg-rest_\d+", rest_arg.name) is not None
vec_call = runtime.nth(rest_arg_fn, 2)
assert runtime.nth(vec_call, 1) == rest_arg


class TestThrow:
def test_throw_not_enough_args(self, lcompile: CompileFn):
with pytest.raises(compiler.CompilerException):
Expand Down
147 changes: 101 additions & 46 deletions tests/basilisp/reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1793,54 +1793,109 @@ def test_splicing_form_in_maps(self):
)


def test_function_reader_macro():
assert read_str_first("#()") == llist.l(sym.symbol("fn*"), vec.v(), None)
assert read_str_first("#(identity %)") == llist.l(
sym.symbol("fn*"),
vec.v(sym.symbol("arg-1")),
llist.l(sym.symbol("identity"), sym.symbol("arg-1")),
)
assert read_str_first("#(identity %1)") == llist.l(
sym.symbol("fn*"),
vec.v(sym.symbol("arg-1")),
llist.l(sym.symbol("identity"), sym.symbol("arg-1")),
)
assert read_str_first("#(identity %& %1)") == llist.l(
sym.symbol("fn*"),
vec.v(sym.symbol("arg-1"), sym.symbol("&"), sym.symbol("arg-rest")),
llist.l(sym.symbol("identity"), sym.symbol("arg-rest"), sym.symbol("arg-1")),
)
assert read_str_first("#(identity %3)") == llist.l(
sym.symbol("fn*"),
vec.v(sym.symbol("arg-1"), sym.symbol("arg-2"), sym.symbol("arg-3")),
llist.l(sym.symbol("identity"), sym.symbol("arg-3")),
)
assert read_str_first("#(identity %3 %&)") == llist.l(
sym.symbol("fn*"),
vec.v(
sym.symbol("arg-1"),
sym.symbol("arg-2"),
sym.symbol("arg-3"),
sym.symbol("&"),
sym.symbol("arg-rest"),
),
llist.l(sym.symbol("identity"), sym.symbol("arg-3"), sym.symbol("arg-rest")),
)
assert read_str_first("#(identity {:arg %})") == llist.l(
sym.symbol("fn*"),
vec.v(
sym.symbol("arg-1"),
),
llist.l(
sym.symbol("identity"), lmap.map({kw.keyword("arg"): sym.symbol("arg-1")})
),
class TestFunctionReaderMacro:
@pytest.mark.parametrize(
"code,v",
[
("#()", llist.l(sym.symbol("fn*"), vec.v(), None)),
(
"#(identity %)",
llist.l(
sym.symbol("fn*"),
vec.v(sym.symbol("arg-1")),
llist.l(sym.symbol("identity"), sym.symbol("arg-1")),
),
),
(
"#(identity %1)",
llist.l(
sym.symbol("fn*"),
vec.v(sym.symbol("arg-1")),
llist.l(sym.symbol("identity"), sym.symbol("arg-1")),
),
),
(
"#(identity %& %1)",
llist.l(
sym.symbol("fn*"),
vec.v(sym.symbol("arg-1"), sym.symbol("&"), sym.symbol("arg-rest")),
llist.l(
sym.symbol("identity"),
sym.symbol("arg-rest"),
sym.symbol("arg-1"),
),
),
),
(
"#(identity %3)",
llist.l(
sym.symbol("fn*"),
vec.v(
sym.symbol("arg-1"), sym.symbol("arg-2"), sym.symbol("arg-3")
),
llist.l(sym.symbol("identity"), sym.symbol("arg-3")),
),
),
(
"#(identity %3 %&)",
llist.l(
sym.symbol("fn*"),
vec.v(
sym.symbol("arg-1"),
sym.symbol("arg-2"),
sym.symbol("arg-3"),
sym.symbol("&"),
sym.symbol("arg-rest"),
),
llist.l(
sym.symbol("identity"),
sym.symbol("arg-3"),
sym.symbol("arg-rest"),
),
),
),
(
"#(identity {:arg %})",
llist.l(
sym.symbol("fn*"),
vec.v(
sym.symbol("arg-1"),
),
llist.l(
sym.symbol("identity"),
lmap.map({kw.keyword("arg"): sym.symbol("arg-1")}),
),
),
),
(
"#(vec %&)",
llist.l(
sym.symbol("fn*"),
vec.v(sym.symbol("&"), sym.symbol("arg-rest")),
llist.l(sym.symbol("vec"), sym.symbol("arg-rest")),
),
),
(
"#(vector %1 %&)",
llist.l(
sym.symbol("fn*"),
vec.v(sym.symbol("arg-1"), sym.symbol("&"), sym.symbol("arg-rest")),
llist.l(
sym.symbol("vector"),
sym.symbol("arg-1"),
sym.symbol("arg-rest"),
),
),
),
],
)
def test_function_reader_macro(self, code: str, v):
assert v == read_str_first(code)

with pytest.raises(reader.SyntaxError):
read_str_first("#(identity #(%1 %2))")

with pytest.raises(reader.SyntaxError):
read_str_first("#app/ermagrd [1 2 3]")
@pytest.mark.parametrize("code", ["#(identity #(%1 %2))", "#app/ermagrd [1 2 3]"])
def test_invalid_function_reader_macro(self, code: str):
with pytest.raises(reader.SyntaxError):
read_str_first(code)


def test_deref():
Expand Down
Loading