@@ -134,11 +134,18 @@ class RecurPoint:
134
134
135
135
@attr .s (auto_attribs = True , frozen = True , slots = True )
136
136
class SymbolTableEntry :
137
- context : LocalType
138
- symbol : sym .Symbol
137
+ binding : Binding
139
138
used : bool = False
140
139
warn_if_unused : bool = True
141
140
141
+ @property
142
+ def symbol (self ) -> sym .Symbol :
143
+ return self .binding .form
144
+
145
+ @property
146
+ def context (self ) -> LocalType :
147
+ return self .binding .local
148
+
142
149
143
150
# pylint: disable=unsupported-membership-test,unsupported-delete-operation,unsupported-assignment-operation
144
151
@attr .s (auto_attribs = True , slots = True )
@@ -149,14 +156,16 @@ class SymbolTable:
149
156
_children : Dict [str , "SymbolTable" ] = attr .ib (factory = dict )
150
157
151
158
def new_symbol (
152
- self , s : sym .Symbol , ctx : LocalType , warn_if_unused : bool = True
159
+ self , s : sym .Symbol , binding : Binding , warn_if_unused : bool = True
153
160
) -> "SymbolTable" :
161
+ assert s == binding .form , "Binding symbol must match passed symbol"
162
+
154
163
if s in self ._table :
155
164
self ._table [s ] = attr .evolve (
156
- self ._table [s ], context = ctx , symbol = s , warn_if_unused = warn_if_unused
165
+ self ._table [s ], binding = binding , warn_if_unused = warn_if_unused
157
166
)
158
167
else :
159
- self ._table [s ] = SymbolTableEntry (ctx , s , warn_if_unused = warn_if_unused )
168
+ self ._table [s ] = SymbolTableEntry (binding , warn_if_unused = warn_if_unused )
160
169
return self
161
170
162
171
def find_symbol (self , s : sym .Symbol ) -> Optional [SymbolTableEntry ]:
@@ -228,6 +237,16 @@ def new_frame(self, name, warn_on_unused_names):
228
237
new_frame ._warn_unused_names ()
229
238
self .pop_frame (name )
230
239
240
+ def _as_env_map (self ) -> Dict [sym .Symbol , lmap .Map ]:
241
+ locals_ = {} if self ._parent is None else self ._parent ._as_env_map ()
242
+ locals_ .update ({k : v .binding .to_map () for k , v in self ._table .items ()})
243
+ return locals_
244
+
245
+ def as_env_map (self ) -> lmap .Map :
246
+ """Return a map of symbols to the local binding objects in the
247
+ local symbol table as of this call."""
248
+ return lmap .map (self ._as_env_map ())
249
+
231
250
232
251
class ParserContext :
233
252
__slots__ = ("_filename" , "_is_quoted" , "_opts" , "_recur_points" , "_st" )
@@ -304,7 +323,7 @@ def symbol_table(self) -> SymbolTable:
304
323
def put_new_symbol ( # pylint: disable=too-many-arguments
305
324
self ,
306
325
s : sym .Symbol ,
307
- sym_ctx : LocalType ,
326
+ binding : Binding ,
308
327
warn_on_shadowed_name : bool = True ,
309
328
warn_on_shadowed_var : bool = True ,
310
329
warn_if_unused : bool = True ,
@@ -336,7 +355,7 @@ def put_new_symbol( # pylint: disable=too-many-arguments
336
355
logger .warning (f"name '{ s } ' shadows def'ed Var from outer scope" )
337
356
if s .meta is not None and s .meta .entry (SYM_NO_WARN_WHEN_UNUSED_META_KEY , None ):
338
357
warn_if_unused = False
339
- st .new_symbol (s , sym_ctx , warn_if_unused = warn_if_unused )
358
+ st .new_symbol (s , binding , warn_if_unused = warn_if_unused )
340
359
341
360
@contextlib .contextmanager
342
361
def new_symbol_table (self , name ):
@@ -548,7 +567,7 @@ def _do_ast(ctx: ParserContext, form: lseq.Seq) -> Do:
548
567
)
549
568
550
569
551
- def __fn_method_ast ( # pylint: disable=too-many-branches
570
+ def __fn_method_ast ( # pylint: disable=too-many-branches,too-many-locals
552
571
ctx : ParserContext , form : lseq .Seq , fnname : Optional [sym .Symbol ] = None
553
572
) -> FnMethod :
554
573
with ctx .new_symbol_table ("fn-method" ):
@@ -571,18 +590,16 @@ def __fn_method_ast( # pylint: disable=too-many-branches
571
590
vargs_idx = i
572
591
break
573
592
574
- param_nodes .append (
575
- Binding (
576
- form = s ,
577
- name = s .name ,
578
- local = LocalType .ARG ,
579
- arg_id = i ,
580
- is_variadic = False ,
581
- env = ctx .get_node_env (),
582
- )
593
+ binding = Binding (
594
+ form = s ,
595
+ name = s .name ,
596
+ local = LocalType .ARG ,
597
+ arg_id = i ,
598
+ is_variadic = False ,
599
+ env = ctx .get_node_env (),
583
600
)
584
-
585
- ctx .put_new_symbol (s , LocalType . ARG )
601
+ param_nodes . append ( binding )
602
+ ctx .put_new_symbol (s , binding )
586
603
587
604
if has_vargs :
588
605
try :
@@ -593,18 +610,16 @@ def __fn_method_ast( # pylint: disable=too-many-branches
593
610
"function rest parameter name must be a symbol" , form = vargs_sym
594
611
)
595
612
596
- param_nodes .append (
597
- Binding (
598
- form = vargs_sym ,
599
- name = vargs_sym .name ,
600
- local = LocalType .ARG ,
601
- arg_id = vargs_idx + 1 ,
602
- is_variadic = True ,
603
- env = ctx .get_node_env (),
604
- )
613
+ binding = Binding (
614
+ form = vargs_sym ,
615
+ name = vargs_sym .name ,
616
+ local = LocalType .ARG ,
617
+ arg_id = vargs_idx + 1 ,
618
+ is_variadic = True ,
619
+ env = ctx .get_node_env (),
605
620
)
606
-
607
- ctx .put_new_symbol (vargs_sym , LocalType . ARG )
621
+ param_nodes . append ( binding )
622
+ ctx .put_new_symbol (vargs_sym , binding )
608
623
except IndexError :
609
624
raise ParserException (
610
625
"Expected variadic argument name after '&'" , form = params
@@ -658,10 +673,11 @@ def _fn_ast( # pylint: disable=too-many-branches # noqa: MC0001
658
673
)
659
674
660
675
if isinstance (name , sym .Symbol ):
661
- ctx .put_new_symbol (name , LocalType .FN , warn_if_unused = False )
662
676
name_node : Optional [Binding ] = Binding (
663
677
form = name , name = name .name , local = LocalType .FN , env = ctx .get_node_env ()
664
678
)
679
+ assert name_node is not None
680
+ ctx .put_new_symbol (name , name_node , warn_if_unused = False )
665
681
idx += 1
666
682
elif isinstance (name , (llist .List , vec .Vector )):
667
683
name = None
@@ -944,7 +960,8 @@ def _invoke_ast(ctx: ParserContext, form: Union[llist.List, lseq.Seq]) -> Node:
944
960
if fn .op == NodeOp .VAR and isinstance (fn , VarRef ):
945
961
if _is_macro (fn .var ):
946
962
try :
947
- expanded = fn .var .value (form , * form .rest )
963
+ macro_env = ctx .symbol_table .as_env_map ()
964
+ expanded = fn .var .value (macro_env , form , * form .rest )
948
965
expanded_ast = _parse_ast (ctx , expanded )
949
966
950
967
# Verify that macroexpanded code also does not have any
@@ -997,18 +1014,16 @@ def _let_ast(ctx: ParserContext, form: lseq.Seq) -> Let:
997
1014
if not isinstance (name , sym .Symbol ):
998
1015
raise ParserException ("let binding name must be a symbol" , form = name )
999
1016
1000
- binding_nodes .append (
1001
- Binding (
1002
- form = name ,
1003
- name = name .name ,
1004
- local = LocalType .LET ,
1005
- init = _parse_ast (ctx , value ),
1006
- children = vec .v (INIT ),
1007
- env = ctx .get_node_env (),
1008
- )
1017
+ binding = Binding (
1018
+ form = name ,
1019
+ name = name .name ,
1020
+ local = LocalType .LET ,
1021
+ init = _parse_ast (ctx , value ),
1022
+ children = vec .v (INIT ),
1023
+ env = ctx .get_node_env (),
1009
1024
)
1010
-
1011
- ctx .put_new_symbol (name , LocalType . LET )
1025
+ binding_nodes . append ( binding )
1026
+ ctx .put_new_symbol (name , binding )
1012
1027
1013
1028
let_body = runtime .nthrest (form , 2 )
1014
1029
* statements , ret = map (partial (_parse_ast , ctx ), let_body )
@@ -1050,17 +1065,15 @@ def _loop_ast(ctx: ParserContext, form: lseq.Seq) -> Loop:
1050
1065
if not isinstance (name , sym .Symbol ):
1051
1066
raise ParserException ("loop binding name must be a symbol" , form = name )
1052
1067
1053
- binding_nodes .append (
1054
- Binding (
1055
- form = name ,
1056
- name = name .name ,
1057
- local = LocalType .LOOP ,
1058
- init = _parse_ast (ctx , value ),
1059
- env = ctx .get_node_env (),
1060
- )
1068
+ binding = Binding (
1069
+ form = name ,
1070
+ name = name .name ,
1071
+ local = LocalType .LOOP ,
1072
+ init = _parse_ast (ctx , value ),
1073
+ env = ctx .get_node_env (),
1061
1074
)
1062
-
1063
- ctx .put_new_symbol (name , LocalType . LOOP )
1075
+ binding_nodes . append ( binding )
1076
+ ctx .put_new_symbol (name , binding )
1064
1077
1065
1078
with ctx .new_recur_point (loop_id , binding_nodes ):
1066
1079
loop_body = runtime .nthrest (form , 2 )
@@ -1222,19 +1235,20 @@ def _catch_ast(ctx: ParserContext, form: lseq.Seq) -> Catch:
1222
1235
raise ParserException ("catch local must be a symbol" , form = local_name )
1223
1236
1224
1237
with ctx .new_symbol_table ("catch" ):
1225
- ctx .put_new_symbol (local_name , LocalType .CATCH )
1238
+ catch_binding = Binding (
1239
+ form = local_name ,
1240
+ name = local_name .name ,
1241
+ local = LocalType .CATCH ,
1242
+ env = ctx .get_node_env (),
1243
+ )
1244
+ ctx .put_new_symbol (local_name , catch_binding )
1226
1245
1227
1246
catch_body = runtime .nthrest (form , 3 )
1228
1247
* catch_statements , catch_ret = map (partial (_parse_ast , ctx ), catch_body )
1229
1248
return Catch (
1230
1249
form = form ,
1231
1250
class_ = catch_cls ,
1232
- local = Binding (
1233
- form = local_name ,
1234
- name = local_name .name ,
1235
- local = LocalType .CATCH ,
1236
- env = ctx .get_node_env (),
1237
- ),
1251
+ local = catch_binding ,
1238
1252
body = Do (
1239
1253
form = catch_body ,
1240
1254
statements = vec .vector (catch_statements ),
0 commit comments