Skip to content

Commit 4a25004

Browse files
authored
Macroexpand convenience function (#394)
* Imports in the time of __init__ * Tests * Allow Python submodules to Basilisp modules * Macroexpansion except way overcomplicated * Actually good, simple macroexpand function * Rebuild * Make MyPy happy * plz
1 parent 0aec10a commit 4a25004

File tree

4 files changed

+122
-20
lines changed

4 files changed

+122
-20
lines changed

src/basilisp/core.lpy

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,19 @@
467467
~fmeta
468468
~@body)))
469469

470+
(defn macroexpand-1
471+
"Macroexpand form one time. Returns the macroexpanded form. The return
472+
value may still represent a macro. Does not macroexpand child forms."
473+
[form]
474+
(basilisp.lang.compiler/macroexpand-1 form))
475+
476+
(defn macroexpand
477+
"Repeatedly macroexpand form as by macroexpand-1 until form no longer
478+
represents a macro. Returns the expanded form. Does not macroexpand child
479+
forms."
480+
[form]
481+
(basilisp.lang.compiler/macroexpand form))
482+
470483
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
471484
;; Logical Comparisons & Macros ;;
472485
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

src/basilisp/lang/compiler/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
WARN_ON_SHADOWED_VAR,
1414
WARN_ON_UNUSED_NAMES,
1515
analyze_form,
16+
macroexpand,
17+
macroexpand_1,
1618
)
1719
from basilisp.lang.compiler.exception import CompilerException, CompilerPhase # noqa
1820
from basilisp.lang.compiler.generator import ( # noqa

src/basilisp/lang/compiler/analyzer.py

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,15 @@ class AnalyzerContext:
281281
"_is_quoted",
282282
"_opts",
283283
"_recur_points",
284+
"_should_macroexpand",
284285
"_st",
285286
)
286287

287288
def __init__(
288-
self, filename: Optional[str] = None, opts: Optional[Mapping[str, bool]] = None
289+
self,
290+
filename: Optional[str] = None,
291+
opts: Optional[Mapping[str, bool]] = None,
292+
should_macroexpand: bool = True,
289293
) -> None:
290294
self._filename = Maybe(filename).or_else_get(DEFAULT_COMPILER_FILE_PATH)
291295
self._func_ctx: Deque[bool] = collections.deque([])
@@ -294,6 +298,7 @@ def __init__(
294298
Maybe(opts).map(lmap.map).or_else_get(lmap.Map.empty()) # type: ignore
295299
)
296300
self._recur_points: Deque[RecurPoint] = collections.deque([])
301+
self._should_macroexpand = should_macroexpand
297302
self._st = collections.deque([SymbolTable("<Top>")])
298303

299304
@property
@@ -339,6 +344,10 @@ def quoted(self):
339344
yield
340345
self._is_quoted.pop()
341346

347+
@property
348+
def should_macroexpand(self) -> bool:
349+
return self._should_macroexpand
350+
342351
@property
343352
def is_async_ctx(self) -> bool:
344353
try:
@@ -1614,25 +1623,26 @@ def _invoke_ast(ctx: AnalyzerContext, form: Union[llist.List, ISeq]) -> Node:
16141623

16151624
if fn.op == NodeOp.VAR and isinstance(fn, VarRef):
16161625
if _is_macro(fn.var):
1617-
try:
1618-
macro_env = ctx.symbol_table.as_env_map()
1619-
expanded = fn.var.value(macro_env, form, *form.rest)
1620-
expanded_ast = _analyze_form(ctx, expanded)
1621-
1622-
# Verify that macroexpanded code also does not have any
1623-
# non-tail recur forms
1624-
if ctx.recur_point is not None:
1625-
_assert_recur_is_tail(expanded_ast)
1626-
1627-
return expanded_ast.assoc(
1628-
raw_forms=cast(vec.Vector, expanded_ast.raw_forms).cons(form)
1629-
)
1630-
except Exception as e:
1631-
raise CompilerException(
1632-
"error occurred during macroexpansion",
1633-
form=form,
1634-
phase=CompilerPhase.MACROEXPANSION,
1635-
) from e
1626+
if ctx.should_macroexpand:
1627+
try:
1628+
macro_env = ctx.symbol_table.as_env_map()
1629+
expanded = fn.var.value(macro_env, form, *form.rest)
1630+
expanded_ast = _analyze_form(ctx, expanded)
1631+
1632+
# Verify that macroexpanded code also does not have any
1633+
# non-tail recur forms
1634+
if ctx.recur_point is not None:
1635+
_assert_recur_is_tail(expanded_ast)
1636+
1637+
return expanded_ast.assoc(
1638+
raw_forms=cast(vec.Vector, expanded_ast.raw_forms).cons(form)
1639+
)
1640+
except Exception as e:
1641+
raise CompilerException(
1642+
"error occurred during macroexpansion",
1643+
form=form,
1644+
phase=CompilerPhase.MACROEXPANSION,
1645+
) from e
16361646

16371647
return Invoke(
16381648
form=form,
@@ -2458,3 +2468,27 @@ def analyze_form(ctx: AnalyzerContext, form: ReaderForm) -> Node:
24582468
"""Take a Lisp form as an argument and produce a Basilisp syntax
24592469
tree matching the clojure.tools.analyzer AST spec."""
24602470
return _analyze_form(ctx, form).assoc(top_level=True)
2471+
2472+
2473+
def macroexpand_1(form: ReaderForm) -> ReaderForm:
2474+
"""Macroexpand form one time. Returns the macroexpanded form. The return
2475+
value may still represent a macro. Does not macroexpand child forms."""
2476+
ctx = AnalyzerContext("<Macroexpand>", should_macroexpand=False)
2477+
maybe_macro = analyze_form(ctx, form)
2478+
if maybe_macro.op == NodeOp.INVOKE:
2479+
assert isinstance(maybe_macro, Invoke)
2480+
2481+
fn = maybe_macro.fn
2482+
if fn.op == NodeOp.VAR and isinstance(fn, VarRef):
2483+
if _is_macro(fn.var):
2484+
assert isinstance(form, ISeq)
2485+
macro_env = ctx.symbol_table.as_env_map()
2486+
return fn.var.value(macro_env, form, *form.rest)
2487+
return maybe_macro.form
2488+
2489+
2490+
def macroexpand(form: ReaderForm) -> ReaderForm:
2491+
"""Repeatedly macroexpand form as by macroexpand-1 until form no longer
2492+
represents a macro. Returns the expanded form. Does not macroexpand child
2493+
forms."""
2494+
return analyze_form(AnalyzerContext("<Macroexpand>"), form).form

tests/basilisp/compiler_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,6 +1740,59 @@ def test_macro_expansion(ns: runtime.Namespace):
17401740
assert llist.l(1, 2, 3) == lcompile("((fn [] '(1 2 3)))")
17411741

17421742

1743+
class TestMacroexpandFunctions:
1744+
@pytest.fixture
1745+
def example_macro(self):
1746+
lcompile(
1747+
"(defmacro parent [] `(defmacro ~'child [] (fn [])))",
1748+
resolver=runtime.resolve_alias,
1749+
)
1750+
1751+
def test_macroexpand_1(self, example_macro):
1752+
assert llist.l(
1753+
sym.symbol("defmacro", ns="basilisp.core"),
1754+
sym.symbol("child"),
1755+
vec.Vector.empty(),
1756+
llist.l(sym.symbol("fn", ns="basilisp.core"), vec.Vector.empty()),
1757+
) == compiler.macroexpand_1(llist.l(sym.symbol("parent")))
1758+
1759+
assert llist.l(
1760+
sym.symbol("add", ns="operator"), 1, 2
1761+
) == compiler.macroexpand_1(llist.l(sym.symbol("add", ns="operator"), 1, 2))
1762+
assert sym.symbol("map") == compiler.macroexpand_1(sym.symbol("map"))
1763+
assert llist.l(sym.symbol("map")) == compiler.macroexpand_1(
1764+
llist.l(sym.symbol("map"))
1765+
)
1766+
assert vec.Vector.empty() == compiler.macroexpand_1(vec.Vector.empty())
1767+
1768+
with pytest.raises(compiler.CompilerException):
1769+
compiler.macroexpand_1(sym.symbol("non-existent-symbol"))
1770+
1771+
def test_macroexpand(self, example_macro):
1772+
assert llist.l(
1773+
sym.symbol("def"),
1774+
sym.symbol("child"),
1775+
llist.l(
1776+
sym.symbol("fn", ns="basilisp.core"),
1777+
sym.symbol("child"),
1778+
vec.v(sym.symbol("&env"), sym.symbol("&form")),
1779+
llist.l(sym.symbol("fn", ns="basilisp.core"), vec.Vector.empty()),
1780+
),
1781+
) == compiler.macroexpand(llist.l(sym.symbol("parent")))
1782+
1783+
assert llist.l(sym.symbol("add", ns="operator"), 1, 2) == compiler.macroexpand(
1784+
llist.l(sym.symbol("add", ns="operator"), 1, 2)
1785+
)
1786+
assert sym.symbol("map") == compiler.macroexpand(sym.symbol("map"))
1787+
assert llist.l(sym.symbol("map")) == compiler.macroexpand(
1788+
llist.l(sym.symbol("map"))
1789+
)
1790+
assert vec.Vector.empty() == compiler.macroexpand(vec.Vector.empty())
1791+
1792+
with pytest.raises(compiler.CompilerException):
1793+
compiler.macroexpand(sym.symbol("non-existent-symbol"))
1794+
1795+
17431796
class TestIf:
17441797
def test_if_number_of_elems(self):
17451798
with pytest.raises(compiler.CompilerException):

0 commit comments

Comments
 (0)