Skip to content

Commit c242fb4

Browse files
authored
enhance python abstract syntax tree API (#2241)
## Changes
Enhance the Python abstract syntax tree API in preparation for linting with inherited context.
Move `Tree` static methods to a `TreeHelper` class to avoid the 'too many public methods' linting error.

### Linked issues
Progresses #2155
Progresses #2156

### Functionality
None

### Tests
- [x] added unit tests

---------

Co-authored-by: Eric Vergnaud <[email protected]>
1 parent 3f25ca5 commit c242fb4

File tree

11 files changed

+262
-137
lines changed

11 files changed

+262
-137
lines changed

src/databricks/labs/ucx/source_code/base.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from pathlib import Path
88

9-
from astroid import AstroidSyntaxError, Module, NodeNG # type: ignore
9+
from astroid import AstroidSyntaxError, NodeNG # type: ignore
1010

1111
from databricks.sdk.service import compute
1212

@@ -221,27 +221,24 @@ def _parse_and_append(self, code: str) -> Tree:
221221
return tree
222222

223223
def append_tree(self, tree: Tree):
224-
if self._tree is None:
225-
self._tree = Tree(Module("root"))
226-
self._tree.append_tree(tree)
224+
self._make_tree().append_tree(tree)
227225

228226
def append_nodes(self, nodes: list[NodeNG]):
229-
if self._tree is None:
230-
self._tree = Tree(Module("root"))
231-
self._tree.append_nodes(nodes)
227+
self._make_tree().append_nodes(nodes)
232228

233229
def append_globals(self, globs: dict):
234-
if self._tree is None:
235-
self._tree = Tree(Module("root"))
236-
self._tree.append_globals(globs)
230+
self._make_tree().append_globals(globs)
237231

238232
def process_child_cell(self, code: str):
239233
try:
234+
this_tree = self._make_tree()
240235
tree = Tree.normalize_and_parse(code)
241-
if self._tree is None:
242-
self._tree = tree
243-
else:
244-
self._tree.append_tree(tree)
236+
this_tree.append_tree(tree)
245237
except AstroidSyntaxError as e:
246238
# error already reported when linting enclosing notebook
247239
logger.warning(f"Failed to parse Python cell: {code}", exc_info=e)
240+
241+
def _make_tree(self) -> Tree:
242+
if self._tree is None:
243+
self._tree = Tree.new_module()
244+
return self._tree

src/databricks/labs/ucx/source_code/graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,15 @@ def visit(self, visit_node: Callable[[DependencyGraph], bool | None], visited: s
196196
return True
197197
return False
198198

199-
def new_graph_builder_context(self):
200-
return GraphBuilderContext(parent=self, path_lookup=self._path_lookup, session_state=self._session_state)
199+
def new_dependency_graph_context(self):
200+
return DependencyGraphContext(parent=self, path_lookup=self._path_lookup, session_state=self._session_state)
201201

202202
def __repr__(self):
203203
return f"<DependencyGraph {self.path}>"
204204

205205

206206
@dataclass
207-
class GraphBuilderContext:
207+
class DependencyGraphContext:
208208
parent: DependencyGraph
209209
path_lookup: PathLookup
210210
session_state: CurrentSessionState

src/databricks/labs/ucx/source_code/linters/files.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from databricks.labs.blueprint.tui import Prompts
1515

1616
from databricks.labs.ucx.source_code.linters.context import LinterContext
17-
from databricks.labs.ucx.source_code.notebooks.cells import CellLanguage, GraphBuilder
17+
from databricks.labs.ucx.source_code.notebooks.cells import CellLanguage, PythonCodeAnalyzer
1818
from databricks.labs.ucx.source_code.graph import (
1919
BaseImportResolver,
2020
BaseFileResolver,
@@ -40,9 +40,9 @@ def __init__(self, path: Path, source: str, language: Language):
4040

4141
def build_dependency_graph(self, parent: DependencyGraph) -> list[DependencyProblem]:
4242
if self._language is CellLanguage.PYTHON:
43-
context = parent.new_graph_builder_context()
44-
builder = GraphBuilder(context)
45-
return builder.build_graph_from_python_source(self._original_code)
43+
context = parent.new_dependency_graph_context()
44+
analyser = PythonCodeAnalyzer(context, self._original_code)
45+
return analyser.build_graph()
4646
# supported language that does not generate dependencies
4747
if self._language is CellLanguage.SQL:
4848
return []

src/databricks/labs/ucx/source_code/linters/pyspark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515
from databricks.labs.ucx.source_code.linters.python_infer import InferredValue
1616
from databricks.labs.ucx.source_code.queries import FromTable
17-
from databricks.labs.ucx.source_code.linters.python_ast import Tree
17+
from databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeHelper
1818

1919

2020
@dataclass
@@ -61,7 +61,7 @@ def _get_table_arg(self, node: Call):
6161
def _check_call_context(self, node: Call) -> bool:
6262
assert isinstance(node.func, Attribute) # Avoid linter warning
6363
func_name = node.func.attrname
64-
qualified_name = Tree.get_full_function_name(node)
64+
qualified_name = TreeHelper.get_full_function_name(node)
6565

6666
# Check if the call_context is None as that means all calls are checked
6767
if self.call_context is None:

src/databricks/labs/ucx/source_code/linters/python_ast.py

Lines changed: 131 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ def _convert_magic_lines_to_magic_commands(cls, python_code: str):
8484
in_multi_line_comment = not in_multi_line_comment
8585
return "\n".join(lines)
8686

87+
@classmethod
88+
def new_module(cls):
89+
node = Module("root")
90+
return Tree(node)
91+
8792
def __init__(self, node: NodeNG):
8893
self._node: NodeNG = node
8994

@@ -118,6 +123,132 @@ def first_statement(self):
118123
return self._node.body[0]
119124
return None
120125

126+
def __repr__(self):
127+
truncate_after = 32
128+
code = repr(self._node)
129+
if len(code) > truncate_after:
130+
code = code[0:truncate_after] + "..."
131+
return f"<Tree: {code}>"
132+
133+
def append_tree(self, tree: Tree) -> Tree:
134+
"""returns the appended tree, not the consolidated one!"""
135+
if not isinstance(tree.node, Module):
136+
raise NotImplementedError(f"Can't append tree from {type(tree.node).__name__}")
137+
tree_module: Module = cast(Module, tree.node)
138+
self.append_nodes(tree_module.body)
139+
self.append_globals(tree_module.globals)
140+
# the following may seem strange but it's actually ok to use the original module as tree root
141+
# because each node points to the correct parent (practically, the tree is now only a list of statements)
142+
return tree
143+
144+
def append_globals(self, globs: dict[str, list[NodeNG]]) -> None:
145+
if not isinstance(self.node, Module):
146+
raise NotImplementedError(f"Can't append globals to {type(self.node).__name__}")
147+
self_module: Module = cast(Module, self.node)
148+
for name, values in globs.items():
149+
statements: list[Expr] = self_module.globals.get(name, None)
150+
if statements is None:
151+
self_module.globals[name] = list(values) # clone the source list to avoid side-effects
152+
continue
153+
statements.extend(values)
154+
155+
def append_nodes(self, nodes: list[NodeNG]) -> None:
156+
if not isinstance(self.node, Module):
157+
raise NotImplementedError(f"Can't append statements to {type(self.node).__name__}")
158+
self_module: Module = cast(Module, self.node)
159+
for node in nodes:
160+
node.parent = self_module
161+
self_module.body.append(node)
162+
163+
def is_from_module(self, module_name: str) -> bool:
164+
# if this is the call's root node, check it against the required module
165+
if isinstance(self._node, Name):
166+
if self._node.name == module_name:
167+
return True
168+
root = self.root
169+
if not isinstance(root, Module):
170+
return False
171+
for value in root.globals.get(self._node.name, []):
172+
if not isinstance(value, AssignName) or not isinstance(value.parent, Assign):
173+
continue
174+
if Tree(value.parent.value).is_from_module(module_name):
175+
return True
176+
return False
177+
# walk up intermediate calls such as spark.range(...)
178+
if isinstance(self._node, Call):
179+
return isinstance(self._node.func, Attribute) and Tree(self._node.func.expr).is_from_module(module_name)
180+
if isinstance(self._node, Attribute):
181+
return Tree(self._node.expr).is_from_module(module_name)
182+
return False
183+
184+
def has_global(self, name: str) -> bool:
185+
if not isinstance(self.node, Module):
186+
return False
187+
self_module: Module = cast(Module, self.node)
188+
return self_module.globals.get(name, None) is not None
189+
190+
def nodes_between(self, first_line: int, last_line: int) -> list[NodeNG]:
191+
if not isinstance(self.node, Module):
192+
raise NotImplementedError(f"Can't extract nodes from {type(self.node).__name__}")
193+
self_module: Module = cast(Module, self.node)
194+
nodes: list[NodeNG] = []
195+
for node in self_module.body:
196+
if node.lineno < first_line:
197+
continue
198+
if node.lineno > last_line:
199+
break
200+
nodes.append(node)
201+
return nodes
202+
203+
def globals_between(self, first_line: int, last_line: int) -> dict[str, list[NodeNG]]:
204+
if not isinstance(self.node, Module):
205+
raise NotImplementedError(f"Can't extract globals from {type(self.node).__name__}")
206+
self_module: Module = cast(Module, self.node)
207+
globs: dict[str, list[NodeNG]] = {}
208+
for key, nodes in self_module.globals.items():
209+
nodes_in_scope: list[NodeNG] = []
210+
for node in nodes:
211+
if node.lineno < first_line or node.lineno > last_line:
212+
continue
213+
nodes_in_scope.append(node)
214+
if len(nodes_in_scope) > 0:
215+
globs[key] = nodes_in_scope
216+
return globs
217+
218+
def line_count(self):
219+
if not isinstance(self.node, Module):
220+
raise NotImplementedError(f"Can't count lines from {type(self.node).__name__}")
221+
self_module: Module = cast(Module, self.node)
222+
nodes_count = len(self_module.body)
223+
if nodes_count == 0:
224+
return 0
225+
return 1 + self_module.body[nodes_count - 1].lineno - self_module.body[0].lineno
226+
227+
def renumber(self, start: int) -> Tree:
228+
assert start != 0
229+
if not isinstance(self.node, Module):
230+
raise NotImplementedError(f"Can't renumber {type(self.node).__name__}")
231+
root: Module = self.node
232+
# for now renumber in place to avoid the complexity of rebuilding the tree with clones
233+
234+
def renumber_node(node: NodeNG, offset: int) -> None:
235+
for child in node.get_children():
236+
renumber_node(child, offset + child.lineno - node.lineno)
237+
if node.end_lineno:
238+
node.end_lineno = node.end_lineno + offset
239+
node.lineno = node.lineno + offset
240+
241+
nodes = root.body if start > 0 else reversed(root.body)
242+
for node in nodes:
243+
offset = start - node.lineno
244+
renumber_node(node, offset)
245+
num_lines = 1 + (node.end_lineno - node.lineno if node.end_lineno else 0)
246+
start = start + num_lines if start > 0 else start - num_lines
247+
return self
248+
249+
250+
class TreeHelper(ABC):
251+
121252
@classmethod
122253
def extract_call_by_name(cls, call: Call, name: str) -> Call | None:
123254
"""Given a call-chain, extract its sub-call by method name (if it has one)"""
@@ -163,13 +294,6 @@ def is_none(cls, node: NodeNG) -> bool:
163294
return False
164295
return node.value is None
165296

166-
def __repr__(self):
167-
truncate_after = 32
168-
code = repr(self._node)
169-
if len(code) > truncate_after:
170-
code = code[0:truncate_after] + "..."
171-
return f"<Tree: {code}>"
172-
173297
@classmethod
174298
def get_full_attribute_name(cls, node: Attribute) -> str:
175299
return cls._get_attribute_value(node)
@@ -210,55 +334,6 @@ def _get_attribute_value(cls, node: Attribute):
210334
logger.debug(f"Missing handler for {name}")
211335
return None
212336

213-
def append_tree(self, tree: Tree) -> Tree:
214-
if not isinstance(tree.node, Module):
215-
raise NotImplementedError(f"Can't append tree from {type(tree.node).__name__}")
216-
tree_module: Module = cast(Module, tree.node)
217-
self.append_nodes(tree_module.body)
218-
self.append_globals(tree_module.globals)
219-
# the following may seem strange but it's actually ok to use the original module as tree root
220-
return tree
221-
222-
def append_globals(self, globs: dict):
223-
if not isinstance(self.node, Module):
224-
raise NotImplementedError(f"Can't append globals to {type(self.node).__name__}")
225-
self_module: Module = cast(Module, self.node)
226-
for name, value in globs.items():
227-
statements: list[Expr] = self_module.globals.get(name, None)
228-
if statements is None:
229-
self_module.globals[name] = list(value) # clone the source list to avoid side-effects
230-
continue
231-
statements.extend(value)
232-
233-
def append_nodes(self, nodes: list[NodeNG]):
234-
if not isinstance(self.node, Module):
235-
raise NotImplementedError(f"Can't append statements to {type(self.node).__name__}")
236-
self_module: Module = cast(Module, self.node)
237-
for node in nodes:
238-
node.parent = self_module
239-
self_module.body.append(node)
240-
241-
def is_from_module(self, module_name: str):
242-
# if this is the call's root node, check it against the required module
243-
if isinstance(self._node, Name):
244-
if self._node.name == module_name:
245-
return True
246-
root = self.root
247-
if not isinstance(root, Module):
248-
return False
249-
for value in root.globals.get(self._node.name, []):
250-
if not isinstance(value, AssignName) or not isinstance(value.parent, Assign):
251-
continue
252-
if Tree(value.parent.value).is_from_module(module_name):
253-
return True
254-
return False
255-
# walk up intermediate calls such as spark.range(...)
256-
if isinstance(self._node, Call):
257-
return isinstance(self._node.func, Attribute) and Tree(self._node.func.expr).is_from_module(module_name)
258-
if isinstance(self._node, Attribute):
259-
return Tree(self._node.expr).is_from_module(module_name)
260-
return False
261-
262337

263338
class TreeVisitor:
264339

src/databricks/labs/ucx/source_code/linters/spark_connect.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
PythonLinter,
1010
CurrentSessionState,
1111
)
12-
from databricks.labs.ucx.source_code.linters.python_ast import Tree
12+
from databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeHelper
1313

1414

1515
@dataclass
@@ -96,7 +96,7 @@ def _lint_sc(self, node: NodeNG) -> Iterator[Advice]:
9696
return
9797
if node.func.attrname not in self._SC_METHODS:
9898
return
99-
function_name = Tree.get_full_function_name(node)
99+
function_name = TreeHelper.get_full_function_name(node)
100100
if not function_name or not function_name.endswith(f"sc.{node.func.attrname}"):
101101
return
102102
yield self._rdd_failure(node)
@@ -149,7 +149,7 @@ def _match_sc_set_log_level(self, node: NodeNG) -> Iterator[Advice]:
149149
return
150150
if node.func.attrname != 'setLogLevel':
151151
return
152-
function_name = Tree.get_full_function_name(node)
152+
function_name = TreeHelper.get_full_function_name(node)
153153
if not function_name or not function_name.endswith('sc.setLogLevel'):
154154
return
155155

@@ -163,7 +163,7 @@ def _match_sc_set_log_level(self, node: NodeNG) -> Iterator[Advice]:
163163
def _match_jvm_log(self, node: NodeNG) -> Iterator[Advice]:
164164
if not isinstance(node, Attribute):
165165
return
166-
attribute_name = Tree.get_full_attribute_name(node)
166+
attribute_name = TreeHelper.get_full_attribute_name(node)
167167
if attribute_name and attribute_name.endswith('org.apache.log4j'):
168168
yield Failure.from_node(
169169
code='spark-logging-in-shared-clusters',
@@ -179,7 +179,7 @@ class UDFMatcher(SharedClusterMatcher):
179179
def lint(self, node: NodeNG) -> Iterator[Advice]:
180180
if not isinstance(node, Call):
181181
return
182-
function_name = Tree.get_function_name(node)
182+
function_name = TreeHelper.get_function_name(node)
183183

184184
if function_name == 'registerJavaFunction':
185185
yield Failure.from_node(
@@ -214,7 +214,7 @@ class CatalogApiMatcher(SharedClusterMatcher):
214214
def lint(self, node: NodeNG) -> Iterator[Advice]:
215215
if not isinstance(node, Attribute):
216216
return
217-
if node.attrname == 'catalog' and Tree.get_full_attribute_name(node).endswith('spark.catalog'):
217+
if node.attrname == 'catalog' and TreeHelper.get_full_attribute_name(node).endswith('spark.catalog'):
218218
yield Failure.from_node(
219219
code='catalog-api-in-shared-clusters',
220220
message=f'spark.catalog functions require DBR 14.3 LTS or above on {self._cluster_type_str()}',
@@ -226,7 +226,7 @@ class CommandContextMatcher(SharedClusterMatcher):
226226
def lint(self, node: NodeNG) -> Iterator[Advice]:
227227
if not isinstance(node, Call):
228228
return
229-
function_name = Tree.get_full_function_name(node)
229+
function_name = TreeHelper.get_full_function_name(node)
230230
if function_name and function_name.endswith('getContext.toJson'):
231231
yield Failure.from_node(
232232
code='to-json-in-shared-clusters',

0 commit comments

Comments
 (0)