@@ -84,6 +84,11 @@ def _convert_magic_lines_to_magic_commands(cls, python_code: str):
8484 in_multi_line_comment = not in_multi_line_comment
8585 return "\n " .join (lines )
8686
87+ @classmethod
88+ def new_module (cls ):
89+ node = Module ("root" )
90+ return Tree (node )
91+
8792 def __init__ (self , node : NodeNG ):
8893 self ._node : NodeNG = node
8994
@@ -118,6 +123,132 @@ def first_statement(self):
118123 return self ._node .body [0 ]
119124 return None
120125
126+ def __repr__ (self ):
127+ truncate_after = 32
128+ code = repr (self ._node )
129+ if len (code ) > truncate_after :
130+ code = code [0 :truncate_after ] + "..."
131+ return f"<Tree: { code } >"
132+
133+ def append_tree (self , tree : Tree ) -> Tree :
134+ """returns the appended tree, not the consolidated one!"""
135+ if not isinstance (tree .node , Module ):
136+ raise NotImplementedError (f"Can't append tree from { type (tree .node ).__name__ } " )
137+ tree_module : Module = cast (Module , tree .node )
138+ self .append_nodes (tree_module .body )
139+ self .append_globals (tree_module .globals )
140+ # the following may seem strange but it's actually ok to use the original module as tree root
141+ # because each node points to the correct parent (practically, the tree is now only a list of statements)
142+ return tree
143+
144+ def append_globals (self , globs : dict [str , list [NodeNG ]]) -> None :
145+ if not isinstance (self .node , Module ):
146+ raise NotImplementedError (f"Can't append globals to { type (self .node ).__name__ } " )
147+ self_module : Module = cast (Module , self .node )
148+ for name , values in globs .items ():
149+ statements : list [Expr ] = self_module .globals .get (name , None )
150+ if statements is None :
151+ self_module .globals [name ] = list (values ) # clone the source list to avoid side-effects
152+ continue
153+ statements .extend (values )
154+
155+ def append_nodes (self , nodes : list [NodeNG ]) -> None :
156+ if not isinstance (self .node , Module ):
157+ raise NotImplementedError (f"Can't append statements to { type (self .node ).__name__ } " )
158+ self_module : Module = cast (Module , self .node )
159+ for node in nodes :
160+ node .parent = self_module
161+ self_module .body .append (node )
162+
163+ def is_from_module (self , module_name : str ) -> bool :
164+ # if this is the call's root node, check it against the required module
165+ if isinstance (self ._node , Name ):
166+ if self ._node .name == module_name :
167+ return True
168+ root = self .root
169+ if not isinstance (root , Module ):
170+ return False
171+ for value in root .globals .get (self ._node .name , []):
172+ if not isinstance (value , AssignName ) or not isinstance (value .parent , Assign ):
173+ continue
174+ if Tree (value .parent .value ).is_from_module (module_name ):
175+ return True
176+ return False
177+ # walk up intermediate calls such as spark.range(...)
178+ if isinstance (self ._node , Call ):
179+ return isinstance (self ._node .func , Attribute ) and Tree (self ._node .func .expr ).is_from_module (module_name )
180+ if isinstance (self ._node , Attribute ):
181+ return Tree (self ._node .expr ).is_from_module (module_name )
182+ return False
183+
184+ def has_global (self , name : str ) -> bool :
185+ if not isinstance (self .node , Module ):
186+ return False
187+ self_module : Module = cast (Module , self .node )
188+ return self_module .globals .get (name , None ) is not None
189+
190+ def nodes_between (self , first_line : int , last_line : int ) -> list [NodeNG ]:
191+ if not isinstance (self .node , Module ):
192+ raise NotImplementedError (f"Can't extract nodes from { type (self .node ).__name__ } " )
193+ self_module : Module = cast (Module , self .node )
194+ nodes : list [NodeNG ] = []
195+ for node in self_module .body :
196+ if node .lineno < first_line :
197+ continue
198+ if node .lineno > last_line :
199+ break
200+ nodes .append (node )
201+ return nodes
202+
203+ def globals_between (self , first_line : int , last_line : int ) -> dict [str , list [NodeNG ]]:
204+ if not isinstance (self .node , Module ):
205+ raise NotImplementedError (f"Can't extract globals from { type (self .node ).__name__ } " )
206+ self_module : Module = cast (Module , self .node )
207+ globs : dict [str , list [NodeNG ]] = {}
208+ for key , nodes in self_module .globals .items ():
209+ nodes_in_scope : list [NodeNG ] = []
210+ for node in nodes :
211+ if node .lineno < first_line or node .lineno > last_line :
212+ continue
213+ nodes_in_scope .append (node )
214+ if len (nodes_in_scope ) > 0 :
215+ globs [key ] = nodes_in_scope
216+ return globs
217+
218+ def line_count (self ):
219+ if not isinstance (self .node , Module ):
220+ raise NotImplementedError (f"Can't count lines from { type (self .node ).__name__ } " )
221+ self_module : Module = cast (Module , self .node )
222+ nodes_count = len (self_module .body )
223+ if nodes_count == 0 :
224+ return 0
225+ return 1 + self_module .body [nodes_count - 1 ].lineno - self_module .body [0 ].lineno
226+
227+ def renumber (self , start : int ) -> Tree :
228+ assert start != 0
229+ if not isinstance (self .node , Module ):
230+ raise NotImplementedError (f"Can't renumber { type (self .node ).__name__ } " )
231+ root : Module = self .node
232+ # for now renumber in place to avoid the complexity of rebuilding the tree with clones
233+
234+ def renumber_node (node : NodeNG , offset : int ) -> None :
235+ for child in node .get_children ():
236+ renumber_node (child , offset + child .lineno - node .lineno )
237+ if node .end_lineno :
238+ node .end_lineno = node .end_lineno + offset
239+ node .lineno = node .lineno + offset
240+
241+ nodes = root .body if start > 0 else reversed (root .body )
242+ for node in nodes :
243+ offset = start - node .lineno
244+ renumber_node (node , offset )
245+ num_lines = 1 + (node .end_lineno - node .lineno if node .end_lineno else 0 )
246+ start = start + num_lines if start > 0 else start - num_lines
247+ return self
248+
249+
250+ class TreeHelper (ABC ):
251+
121252 @classmethod
122253 def extract_call_by_name (cls , call : Call , name : str ) -> Call | None :
123254 """Given a call-chain, extract its sub-call by method name (if it has one)"""
@@ -163,13 +294,6 @@ def is_none(cls, node: NodeNG) -> bool:
163294 return False
164295 return node .value is None
165296
166- def __repr__ (self ):
167- truncate_after = 32
168- code = repr (self ._node )
169- if len (code ) > truncate_after :
170- code = code [0 :truncate_after ] + "..."
171- return f"<Tree: { code } >"
172-
173297 @classmethod
174298 def get_full_attribute_name (cls , node : Attribute ) -> str :
175299 return cls ._get_attribute_value (node )
@@ -210,55 +334,6 @@ def _get_attribute_value(cls, node: Attribute):
210334 logger .debug (f"Missing handler for { name } " )
211335 return None
212336
213- def append_tree (self , tree : Tree ) -> Tree :
214- if not isinstance (tree .node , Module ):
215- raise NotImplementedError (f"Can't append tree from { type (tree .node ).__name__ } " )
216- tree_module : Module = cast (Module , tree .node )
217- self .append_nodes (tree_module .body )
218- self .append_globals (tree_module .globals )
219- # the following may seem strange but it's actually ok to use the original module as tree root
220- return tree
221-
222- def append_globals (self , globs : dict ):
223- if not isinstance (self .node , Module ):
224- raise NotImplementedError (f"Can't append globals to { type (self .node ).__name__ } " )
225- self_module : Module = cast (Module , self .node )
226- for name , value in globs .items ():
227- statements : list [Expr ] = self_module .globals .get (name , None )
228- if statements is None :
229- self_module .globals [name ] = list (value ) # clone the source list to avoid side-effects
230- continue
231- statements .extend (value )
232-
233- def append_nodes (self , nodes : list [NodeNG ]):
234- if not isinstance (self .node , Module ):
235- raise NotImplementedError (f"Can't append statements to { type (self .node ).__name__ } " )
236- self_module : Module = cast (Module , self .node )
237- for node in nodes :
238- node .parent = self_module
239- self_module .body .append (node )
240-
241- def is_from_module (self , module_name : str ):
242- # if this is the call's root node, check it against the required module
243- if isinstance (self ._node , Name ):
244- if self ._node .name == module_name :
245- return True
246- root = self .root
247- if not isinstance (root , Module ):
248- return False
249- for value in root .globals .get (self ._node .name , []):
250- if not isinstance (value , AssignName ) or not isinstance (value .parent , Assign ):
251- continue
252- if Tree (value .parent .value ).is_from_module (module_name ):
253- return True
254- return False
255- # walk up intermediate calls such as spark.range(...)
256- if isinstance (self ._node , Call ):
257- return isinstance (self ._node .func , Attribute ) and Tree (self ._node .func .expr ).is_from_module (module_name )
258- if isinstance (self ._node , Attribute ):
259- return Tree (self ._node .expr ).is_from_module (module_name )
260- return False
261-
262337
263338class TreeVisitor :
264339
0 commit comments