22
33from dataclasses import dataclass
44from enum import Enum
5- from typing import Callable , TypeAlias , Iterable
5+ from typing import Callable , TypeAlias , Iterable , TypeVar
66
7+ from util import flatten_force
78from ..common import HasRegion , StrRegion
89
910__all__ = [
1011 "AstNode" , "AstProgramNode" , "VarDeclScope" , "VarDeclType" , "AstDeclNode" ,
1112 "AstRepeat" , "AstIf" , "AstWhile" , "AstAssign" , "AstAugAssign" , "AstDefine" ,
1213 "AstNumber" , "AstString" , "AstAnyName" , "AstIdent" , "AstAttrName" ,
1314 "AstListLiteral" , "AstAttribute" , "AstItem" , "AstCall" , "AstOp" , "AstBinOp" ,
14- "AstUnaryOp" , 'walk_ast' , 'WalkableT' , 'WalkerFnT' , 'WalkerCallType'
15+ "AstUnaryOp" , 'walk_ast' , 'WalkableT' , 'WalkerFnT' , 'WalkerCallType' ,
16+ "FilteredWalker"
1517]
1618
1719
@@ -48,7 +50,6 @@ def _walk_obj_members(cls, o: WalkableT, fn: WalkerFnT):
4850 if o is None :
4951 return
5052 if isinstance (o , AstNode ):
51- # noinspection PyProtectedMember
5253 return o ._walk_members (fn )
5354 try :
5455 it = iter (o )
@@ -75,6 +76,122 @@ def walk_multiple_objects(cls, fn: WalkerFnT, objs: Iterable[WalkableT]):
7576walk_ast = AstNode .walk_obj
7677
7778
79+ # region <FilteredWalker>
80+ WT = TypeVar ('WT' , bound = WalkableT )
81+ VT = TypeVar ('VT' )
82+ SpecificCbT = Callable [[WT ], bool | None ]
83+ SpecificCbsDict = dict [type [WT ] | type , list [Callable [[WT ], bool | None ]]]
84+ BothCbT = Callable [[WT , WalkerCallType ], bool | None ]
85+ BothCbsDict = dict [type [WT ] | type , list [Callable [[WT , WalkerCallType ], bool | None ]]]
86+
87+
88+ class WalkerFilterRegistry :
89+ def __init__ (self , enter_cbs : SpecificCbsDict = (),
90+ exit_cbs : SpecificCbsDict = (),
91+ both_sbc : BothCbsDict = ()):
92+ self .enter_cbs : SpecificCbsDict = dict (enter_cbs ) # Copy them,
93+ self .exit_cbs : SpecificCbsDict = dict (exit_cbs ) # also converts default () -> {}
94+ self .both_cbs : BothCbsDict = dict (both_sbc )
95+
96+ def copy (self ):
97+ return WalkerFilterRegistry (self .enter_cbs , self .exit_cbs , self .both_cbs )
98+
99+ def register_both (self , t : type [WT ], fn : Callable [[WT , WalkerCallType ], bool | None ]):
100+ self .both_cbs .setdefault (t , []).append (fn )
101+ return self
102+
103+ def register_enter (self , t : type [WT ], fn : Callable [[WT ], bool | None ]):
104+ self .enter_cbs .setdefault (t , []).append (fn )
105+ return self
106+
107+ def register_exit (self , t : type [WT ], fn : Callable [[WT ], bool | None ]):
108+ self .exit_cbs .setdefault (t , []).append (fn )
109+ return self
110+
111+ def __call__ (self , * args , ** kwargs ):
112+ return self
113+
114+ def on_enter (self , * tps : type [WT ] | type ):
115+ """Decorator version of register_enter."""
116+ def decor (fn : SpecificCbT ):
117+ for t in tps :
118+ self .register_enter (t , fn )
119+ return fn
120+ return decor
121+
122+ def on_exit (self , * tps : type [WT ] | type ):
123+ """Decorator version of register_exit."""
124+ def decor (fn : SpecificCbT ):
125+ for t in tps :
126+ self .register_exit (t , fn )
127+ return fn
128+ return decor
129+
130+ def on_both (self , * tps : type [WT ] | type ):
131+ """Decorator version of register_both."""
132+ def decor (fn : BothCbT ):
133+ for t in tps :
134+ self .register_both (t , fn )
135+ return fn
136+ return decor
137+
138+
139+ class FilteredWalker (WalkerFilterRegistry ):
140+ def __init__ (self ):
141+ cls_reg = self .class_registry ()
142+ super ().__init__ (cls_reg .enter_cbs , cls_reg .exit_cbs , cls_reg .both_cbs )
143+
144+ @classmethod
145+ def class_registry (cls ) -> WalkerFilterRegistry :
146+ return WalkerFilterRegistry ()
147+
148+ @classmethod
149+ def create_cls_registry (cls , fn = None ):
150+ """Create a class-level registry that can be added to using decorators.
151+
152+ This can be used in two ways (at the top of your class)::
153+
154+ # MUST be this name
155+ class_registry = FilteredWalker.create_cls_registry()
156+
157+ or::
158+
159+ @classmethod
160+ @FilteredWalker.create_cls_registry
161+ def class_registry(cls): # MUST be this name
162+ pass
163+
164+ and when registering methods::
165+
166+ @class_registry.on_enter(AstDefine)
167+ def enter_define(self, ...):
168+ ...
169+
170+ The restrictions on name are because we have no other way of detecting
171+ it (without metaclass dark magic) as we can't refer to the class while
172+ its namespace is being evaluated
173+ """
174+ if fn is not None and (parent := fn (cls )) is not None :
175+ return WalkerFilterRegistry .copy (parent )
176+ return WalkerFilterRegistry ()
177+
178+ def __call__ (self , o : WalkableT , call_type : WalkerCallType ):
179+ result = None
180+ # Call more specific ones first
181+ specific_cbs = self .enter_cbs if call_type == WalkerCallType .PRE else self .exit_cbs
182+ for fn in self ._get_funcs (specific_cbs , type (o )):
183+ result = fn (o ) or result
184+ for fn in self ._get_funcs (self .both_cbs , type (o )):
185+ result = fn (o , call_type ) or result
186+ return result
187+
188+ @classmethod
189+ def _get_funcs (cls , mapping : dict [type [WT ] | type , list [VT ]], tp : type [WT ]) -> list [VT ]:
190+ """Also looks at superclasses/MRO"""
191+ return flatten_force (mapping .get (sub , []) for sub in tp .mro ())
192+ # endregion
193+
194+
78195@dataclass
79196class AstProgramNode (AstNode ):
80197 name = 'program'
0 commit comments