Skip to content

Commit 811bfc5

Browse files
authored
Pass macro environment as the first argument to macro functions (#339)
* Pass macro environment as the first argument to macro functions * Formatting
1 parent eef6fb0 commit 811bfc5

File tree

6 files changed

+302
-69
lines changed

6 files changed

+302
-69
lines changed

src/basilisp/core/__init__.lpy

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@
6161
(.with-meta o meta)))
6262

6363
(def ^:macro ^:redef let
64-
(fn* let [&form & decl]
64+
(fn* let [&env &form & decl]
6565
(cons 'let* decl)))
6666

6767
(def ^:macro ^:redef loop
68-
(fn* loop [&form & decl]
68+
(fn* loop [&env &form & decl]
6969
(cons 'loop* decl)))
7070

7171
(def ^:macro ^:redef fn
72-
(fn* fn [&form & decl]
72+
(fn* fn [&env &form & decl]
7373
(with-meta
7474
(cons 'fn* decl)
7575
(meta &form))))
@@ -264,7 +264,7 @@
264264
:doc "Define a new function with an optional docstring."
265265
:arglists '([name & body] [name doc & body])}
266266
defn
267-
(fn defn [&form name & body]
267+
(fn defn [&env &form name & body]
268268
(if (symbol? name)
269269
nil ;; Do nothing!
270270
(throw (ex-info "First argument to defn must be a symbol"
@@ -382,7 +382,7 @@
382382
:doc "Define a new macro like defn. Macro functions are available to the
383383
compiler during macroexpansion."}
384384
defmacro
385-
(fn defmacro [&form name & body]
385+
(fn defmacro [&env &form name & body]
386386
(let [body (concat body)
387387
doc (if (string? (first body))
388388
(first body)
@@ -408,7 +408,10 @@
408408
add-implicit-args (fn [body]
409409
(cons
410410
(if (vector? (first body))
411-
(apply vector (cons '&form (first body)))
411+
(apply vector
412+
(cons '&env
413+
(cons '&form
414+
(first body))))
412415
(throw
413416
(ex-info "Expected an argument vector"
414417
{:found (first body)})))
@@ -2245,3 +2248,29 @@
22452248
(let [[bindings body] (loop-with-destructuring bindings body)]
22462249
`(loop* ~bindings
22472250
~@body)))
2251+
2252+
;;;;;;;;;;;;;;;;;;;;;;;
2253+
;; Interop Functions ;;
2254+
;;;;;;;;;;;;;;;;;;;;;;;
2255+
2256+
(defn lisp->py
2257+
"Recursively convert Basilisp data structures into Python data structures.
2258+
2259+
Callers can specify a keyword argument :keyword-fn, which names a function
2260+
which is called for each keyword value in the input structure to return a
2261+
new value. By default :keyword-fn is the function `name`."
2262+
([o]
2263+
(basilisp.lang.runtime/to-py o))
2264+
([o & {:keys [keyword-fn] :or {keyword-fn name}}]
2265+
(basilisp.lang.runtime/to-py o keyword-fn)))
2266+
2267+
(defn py->lisp
2268+
"Recursively convert Python data structures into Basilisp data structures.
2269+
2270+
Callers can specify a keyword argument :keywordize-keys, which defaults to
2271+
true. If :keywordize-keys is true, then all string keys in Python dicts will
2272+
be converted into keywords in the final return value."
2273+
([o]
2274+
(basilisp.lang.runtime/to-lisp o))
2275+
([o & {:keys [keywordize-keys] :or {keywordize-keys true}}]
2276+
(basilisp.lang.runtime/to-lisp o keywordize-keys)))

src/basilisp/lang/compiler/nodes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import basilisp.lang.set as lset
2020
import basilisp.lang.symbol as sym
2121
import basilisp.lang.vector as vec
22-
from basilisp.lang.runtime import Namespace, Var
22+
from basilisp.lang.runtime import Namespace, Var, to_lisp
2323
from basilisp.lang.typing import LispForm, ReaderForm as ReaderLispForm, SpecialForm
2424
from basilisp.lang.util import munge
2525

@@ -120,6 +120,9 @@ def top_level(self) -> bool:
120120
def env(self) -> "NodeEnv":
121121
raise NotImplementedError()
122122

123+
def to_map(self) -> lmap.Map:
124+
return to_lisp(attr.asdict(self))
125+
123126
def assoc(self, **kwargs):
124127
return attr.evolve(self, **kwargs)
125128

src/basilisp/lang/compiler/parser.py

Lines changed: 74 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,18 @@ class RecurPoint:
134134

135135
@attr.s(auto_attribs=True, frozen=True, slots=True)
136136
class SymbolTableEntry:
137-
context: LocalType
138-
symbol: sym.Symbol
137+
binding: Binding
139138
used: bool = False
140139
warn_if_unused: bool = True
141140

141+
@property
142+
def symbol(self) -> sym.Symbol:
143+
return self.binding.form
144+
145+
@property
146+
def context(self) -> LocalType:
147+
return self.binding.local
148+
142149

143150
# pylint: disable=unsupported-membership-test,unsupported-delete-operation,unsupported-assignment-operation
144151
@attr.s(auto_attribs=True, slots=True)
@@ -149,14 +156,16 @@ class SymbolTable:
149156
_children: Dict[str, "SymbolTable"] = attr.ib(factory=dict)
150157

151158
def new_symbol(
152-
self, s: sym.Symbol, ctx: LocalType, warn_if_unused: bool = True
159+
self, s: sym.Symbol, binding: Binding, warn_if_unused: bool = True
153160
) -> "SymbolTable":
161+
assert s == binding.form, "Binding symbol must match passed symbol"
162+
154163
if s in self._table:
155164
self._table[s] = attr.evolve(
156-
self._table[s], context=ctx, symbol=s, warn_if_unused=warn_if_unused
165+
self._table[s], binding=binding, warn_if_unused=warn_if_unused
157166
)
158167
else:
159-
self._table[s] = SymbolTableEntry(ctx, s, warn_if_unused=warn_if_unused)
168+
self._table[s] = SymbolTableEntry(binding, warn_if_unused=warn_if_unused)
160169
return self
161170

162171
def find_symbol(self, s: sym.Symbol) -> Optional[SymbolTableEntry]:
@@ -228,6 +237,16 @@ def new_frame(self, name, warn_on_unused_names):
228237
new_frame._warn_unused_names()
229238
self.pop_frame(name)
230239

240+
def _as_env_map(self) -> Dict[sym.Symbol, lmap.Map]:
241+
locals_ = {} if self._parent is None else self._parent._as_env_map()
242+
locals_.update({k: v.binding.to_map() for k, v in self._table.items()})
243+
return locals_
244+
245+
def as_env_map(self) -> lmap.Map:
246+
"""Return a map of symbols to the local binding objects in the
247+
local symbol table as of this call."""
248+
return lmap.map(self._as_env_map())
249+
231250

232251
class ParserContext:
233252
__slots__ = ("_filename", "_is_quoted", "_opts", "_recur_points", "_st")
@@ -304,7 +323,7 @@ def symbol_table(self) -> SymbolTable:
304323
def put_new_symbol( # pylint: disable=too-many-arguments
305324
self,
306325
s: sym.Symbol,
307-
sym_ctx: LocalType,
326+
binding: Binding,
308327
warn_on_shadowed_name: bool = True,
309328
warn_on_shadowed_var: bool = True,
310329
warn_if_unused: bool = True,
@@ -336,7 +355,7 @@ def put_new_symbol( # pylint: disable=too-many-arguments
336355
logger.warning(f"name '{s}' shadows def'ed Var from outer scope")
337356
if s.meta is not None and s.meta.entry(SYM_NO_WARN_WHEN_UNUSED_META_KEY, None):
338357
warn_if_unused = False
339-
st.new_symbol(s, sym_ctx, warn_if_unused=warn_if_unused)
358+
st.new_symbol(s, binding, warn_if_unused=warn_if_unused)
340359

341360
@contextlib.contextmanager
342361
def new_symbol_table(self, name):
@@ -548,7 +567,7 @@ def _do_ast(ctx: ParserContext, form: lseq.Seq) -> Do:
548567
)
549568

550569

551-
def __fn_method_ast( # pylint: disable=too-many-branches
570+
def __fn_method_ast( # pylint: disable=too-many-branches,too-many-locals
552571
ctx: ParserContext, form: lseq.Seq, fnname: Optional[sym.Symbol] = None
553572
) -> FnMethod:
554573
with ctx.new_symbol_table("fn-method"):
@@ -571,18 +590,16 @@ def __fn_method_ast( # pylint: disable=too-many-branches
571590
vargs_idx = i
572591
break
573592

574-
param_nodes.append(
575-
Binding(
576-
form=s,
577-
name=s.name,
578-
local=LocalType.ARG,
579-
arg_id=i,
580-
is_variadic=False,
581-
env=ctx.get_node_env(),
582-
)
593+
binding = Binding(
594+
form=s,
595+
name=s.name,
596+
local=LocalType.ARG,
597+
arg_id=i,
598+
is_variadic=False,
599+
env=ctx.get_node_env(),
583600
)
584-
585-
ctx.put_new_symbol(s, LocalType.ARG)
601+
param_nodes.append(binding)
602+
ctx.put_new_symbol(s, binding)
586603

587604
if has_vargs:
588605
try:
@@ -593,18 +610,16 @@ def __fn_method_ast( # pylint: disable=too-many-branches
593610
"function rest parameter name must be a symbol", form=vargs_sym
594611
)
595612

596-
param_nodes.append(
597-
Binding(
598-
form=vargs_sym,
599-
name=vargs_sym.name,
600-
local=LocalType.ARG,
601-
arg_id=vargs_idx + 1,
602-
is_variadic=True,
603-
env=ctx.get_node_env(),
604-
)
613+
binding = Binding(
614+
form=vargs_sym,
615+
name=vargs_sym.name,
616+
local=LocalType.ARG,
617+
arg_id=vargs_idx + 1,
618+
is_variadic=True,
619+
env=ctx.get_node_env(),
605620
)
606-
607-
ctx.put_new_symbol(vargs_sym, LocalType.ARG)
621+
param_nodes.append(binding)
622+
ctx.put_new_symbol(vargs_sym, binding)
608623
except IndexError:
609624
raise ParserException(
610625
"Expected variadic argument name after '&'", form=params
@@ -658,10 +673,11 @@ def _fn_ast( # pylint: disable=too-many-branches # noqa: MC0001
658673
)
659674

660675
if isinstance(name, sym.Symbol):
661-
ctx.put_new_symbol(name, LocalType.FN, warn_if_unused=False)
662676
name_node: Optional[Binding] = Binding(
663677
form=name, name=name.name, local=LocalType.FN, env=ctx.get_node_env()
664678
)
679+
assert name_node is not None
680+
ctx.put_new_symbol(name, name_node, warn_if_unused=False)
665681
idx += 1
666682
elif isinstance(name, (llist.List, vec.Vector)):
667683
name = None
@@ -944,7 +960,8 @@ def _invoke_ast(ctx: ParserContext, form: Union[llist.List, lseq.Seq]) -> Node:
944960
if fn.op == NodeOp.VAR and isinstance(fn, VarRef):
945961
if _is_macro(fn.var):
946962
try:
947-
expanded = fn.var.value(form, *form.rest)
963+
macro_env = ctx.symbol_table.as_env_map()
964+
expanded = fn.var.value(macro_env, form, *form.rest)
948965
expanded_ast = _parse_ast(ctx, expanded)
949966

950967
# Verify that macroexpanded code also does not have any
@@ -997,18 +1014,16 @@ def _let_ast(ctx: ParserContext, form: lseq.Seq) -> Let:
9971014
if not isinstance(name, sym.Symbol):
9981015
raise ParserException("let binding name must be a symbol", form=name)
9991016

1000-
binding_nodes.append(
1001-
Binding(
1002-
form=name,
1003-
name=name.name,
1004-
local=LocalType.LET,
1005-
init=_parse_ast(ctx, value),
1006-
children=vec.v(INIT),
1007-
env=ctx.get_node_env(),
1008-
)
1017+
binding = Binding(
1018+
form=name,
1019+
name=name.name,
1020+
local=LocalType.LET,
1021+
init=_parse_ast(ctx, value),
1022+
children=vec.v(INIT),
1023+
env=ctx.get_node_env(),
10091024
)
1010-
1011-
ctx.put_new_symbol(name, LocalType.LET)
1025+
binding_nodes.append(binding)
1026+
ctx.put_new_symbol(name, binding)
10121027

10131028
let_body = runtime.nthrest(form, 2)
10141029
*statements, ret = map(partial(_parse_ast, ctx), let_body)
@@ -1050,17 +1065,15 @@ def _loop_ast(ctx: ParserContext, form: lseq.Seq) -> Loop:
10501065
if not isinstance(name, sym.Symbol):
10511066
raise ParserException("loop binding name must be a symbol", form=name)
10521067

1053-
binding_nodes.append(
1054-
Binding(
1055-
form=name,
1056-
name=name.name,
1057-
local=LocalType.LOOP,
1058-
init=_parse_ast(ctx, value),
1059-
env=ctx.get_node_env(),
1060-
)
1068+
binding = Binding(
1069+
form=name,
1070+
name=name.name,
1071+
local=LocalType.LOOP,
1072+
init=_parse_ast(ctx, value),
1073+
env=ctx.get_node_env(),
10611074
)
1062-
1063-
ctx.put_new_symbol(name, LocalType.LOOP)
1075+
binding_nodes.append(binding)
1076+
ctx.put_new_symbol(name, binding)
10641077

10651078
with ctx.new_recur_point(loop_id, binding_nodes):
10661079
loop_body = runtime.nthrest(form, 2)
@@ -1222,19 +1235,20 @@ def _catch_ast(ctx: ParserContext, form: lseq.Seq) -> Catch:
12221235
raise ParserException("catch local must be a symbol", form=local_name)
12231236

12241237
with ctx.new_symbol_table("catch"):
1225-
ctx.put_new_symbol(local_name, LocalType.CATCH)
1238+
catch_binding = Binding(
1239+
form=local_name,
1240+
name=local_name.name,
1241+
local=LocalType.CATCH,
1242+
env=ctx.get_node_env(),
1243+
)
1244+
ctx.put_new_symbol(local_name, catch_binding)
12261245

12271246
catch_body = runtime.nthrest(form, 3)
12281247
*catch_statements, catch_ret = map(partial(_parse_ast, ctx), catch_body)
12291248
return Catch(
12301249
form=form,
12311250
class_=catch_cls,
1232-
local=Binding(
1233-
form=local_name,
1234-
name=local_name.name,
1235-
local=LocalType.CATCH,
1236-
env=ctx.get_node_env(),
1237-
),
1251+
local=catch_binding,
12381252
body=Do(
12391253
form=catch_body,
12401254
statements=vec.vector(catch_statements),

0 commit comments

Comments
 (0)