Skip to content

Commit a9948d7

Browse files
authored
Infer values from child notebook in magic line (#2091)
## Changes

When linting notebooks, use values from child notebooks loaded via the `%run` magic line to improve value inference.

### Linked issues

Resolves #1201
Progresses #1901

### Functionality

None

### Tests

- [x] added unit tests

---------

Co-authored-by: Eric Vergnaud <[email protected]>
1 parent 3487dcd commit a9948d7

22 files changed

+375
-143
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,11 @@ path = ".venv"
9797
test = "pytest -n 4 --cov src --cov-report=xml --timeout 30 tests/unit --durations 20"
9898
coverage = "pytest -n auto --cov src tests/unit --timeout 30 --cov-report=html --durations 20"
9999
integration = "pytest -n 10 --cov src tests/integration --durations 20"
100-
fmt = ["black . --extend-exclude 'tests/unit/source_code/samples/*' --extend-exclude dist",
100+
fmt = ["black . --extend-exclude 'tests/unit/source_code/samples/'",
101101
"ruff check . --fix",
102102
"mypy --disable-error-code 'annotation-unchecked' --exclude 'tests/unit/source_code/samples/*' --exclude dist .",
103103
"pylint --output-format=colorized -j 0 src tests"]
104-
verify = ["black --check . --extend-exclude 'tests/unit/source_code/samples/*' --extend-exclude dist",
104+
verify = ["black --check . --extend-exclude 'tests/unit/source_code/samples/'",
105105
"ruff check .",
106106
"mypy --exclude 'tests/unit/source_code/samples/*' --exclude dist .",
107107
"pylint --output-format=colorized -j 0 src tests"]

src/databricks/labs/ucx/source_code/base.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from pathlib import Path
88

9-
from astroid import AstroidSyntaxError, NodeNG # type: ignore
9+
from astroid import AstroidSyntaxError, Module, NodeNG # type: ignore
1010

1111
from databricks.sdk.service import compute
1212

@@ -205,23 +205,42 @@ def __init__(self, linters: list[PythonLinter]):
205205

206206
def lint(self, code: str) -> Iterable[Advice]:
207207
try:
208-
tree = Tree.normalize_and_parse(code)
209-
if self._tree is None:
210-
self._tree = tree
211-
else:
212-
tree = self._tree.append_statements(tree)
213-
for linter in self._linters:
214-
yield from linter.lint_tree(tree)
208+
tree = self._parse_and_append(code)
209+
yield from self.lint_tree(tree)
215210
except AstroidSyntaxError as e:
216211
yield Failure('syntax-error', str(e), 0, 0, 0, 0)
217212

213+
def lint_tree(self, tree: Tree) -> Iterable[Advice]:
214+
for linter in self._linters:
215+
yield from linter.lint_tree(tree)
216+
217+
def _parse_and_append(self, code: str) -> Tree:
218+
tree = Tree.normalize_and_parse(code)
219+
self.append_tree(tree)
220+
return tree
221+
222+
def append_tree(self, tree: Tree):
223+
if self._tree is None:
224+
self._tree = Tree(Module("root"))
225+
self._tree.append_tree(tree)
226+
227+
def append_nodes(self, nodes: list[NodeNG]):
228+
if self._tree is None:
229+
self._tree = Tree(Module("root"))
230+
self._tree.append_nodes(nodes)
231+
232+
def append_globals(self, globs: dict):
233+
if self._tree is None:
234+
self._tree = Tree(Module("root"))
235+
self._tree.append_globals(globs)
236+
218237
def process_child_cell(self, code: str):
219238
try:
220239
tree = Tree.normalize_and_parse(code)
221240
if self._tree is None:
222241
self._tree = tree
223242
else:
224-
self._tree.append_statements(tree)
243+
self._tree.append_tree(tree)
225244
except AstroidSyntaxError as e:
226245
# error already reported when linting enclosing notebook
227246
logger.warning(f"Failed to parse Python cell: {code}", exc_info=e)

src/databricks/labs/ucx/source_code/jobs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -381,16 +381,16 @@ def _lint_task(self, task: jobs.Task, job: jobs.Job):
381381
if not container:
382382
continue
383383
if isinstance(container, Notebook):
384-
yield from self._lint_notebook(container, ctx)
384+
yield from self._lint_notebook(container, ctx, session_state)
385385
if isinstance(container, LocalFile):
386-
yield from self._lint_file(container, ctx)
386+
yield from self._lint_file(container, ctx, session_state)
387387

388-
def _lint_file(self, file: LocalFile, ctx: LinterContext):
389-
linter = FileLinter(ctx, self._path_lookup, file.path)
388+
def _lint_file(self, file: LocalFile, ctx: LinterContext, session_state: CurrentSessionState):
389+
linter = FileLinter(ctx, self._path_lookup, session_state, file.path)
390390
for advice in linter.lint():
391391
yield file.path, advice
392392

393-
def _lint_notebook(self, notebook: Notebook, ctx: LinterContext):
394-
linter = NotebookLinter(ctx, self._path_lookup, notebook)
393+
def _lint_notebook(self, notebook: Notebook, ctx: LinterContext, session_state: CurrentSessionState):
394+
linter = NotebookLinter(ctx, self._path_lookup, session_state, notebook)
395395
for advice in linter.lint():
396396
yield notebook.path, advice

src/databricks/labs/ucx/source_code/known.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from databricks.labs.blueprint.entrypoint import get_logger
1515

1616
from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
17+
from databricks.labs.ucx.source_code.base import CurrentSessionState
1718
from databricks.labs.ucx.source_code.graph import DependencyProblem
1819
from databricks.labs.ucx.source_code.linters.context import LinterContext
1920
from databricks.labs.ucx.source_code.notebooks.sources import FileLinter
@@ -172,8 +173,9 @@ def _analyze_file(cls, known_distributions, library_root, dist_info, module_path
172173
if module_ref.endswith(suffix):
173174
module_ref = module_ref[: -len(suffix)]
174175
logger.info(f"Processing module: {module_ref}")
175-
ctx = LinterContext(empty_index)
176-
linter = FileLinter(ctx, PathLookup.from_sys_path(module_path.parent), module_path)
176+
session_state = CurrentSessionState()
177+
ctx = LinterContext(empty_index, session_state)
178+
linter = FileLinter(ctx, PathLookup.from_sys_path(module_path.parent), session_state, module_path)
177179
known_problems = set()
178180
for problem in linter.lint():
179181
known_problems.add(KnownProblem(problem.code, problem.message))

src/databricks/labs/ucx/source_code/linters/files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def _lint_one(self, path: Path) -> Iterable[LocatedAdvice]:
139139
if path.is_dir():
140140
return []
141141
ctx = self._new_linter_context()
142-
linter = FileLinter(ctx, self._path_lookup, path)
142+
linter = FileLinter(ctx, self._path_lookup, self._session_state, path)
143143
return [advice.for_path(path) for advice in linter.lint()]
144144

145145

src/databricks/labs/ucx/source_code/linters/imports.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import abc
44
import logging
55
from collections.abc import Callable, Iterable
6+
from pathlib import Path
67
from typing import TypeVar, cast
78

89
from astroid import ( # type: ignore
@@ -19,6 +20,7 @@
1920
from databricks.labs.ucx.source_code.base import Advice, Advisory, CurrentSessionState, PythonLinter
2021
from databricks.labs.ucx.source_code.linters.python_ast import Tree, NodeBase, TreeVisitor
2122
from databricks.labs.ucx.source_code.linters.python_infer import InferredValue
23+
from databricks.labs.ucx.source_code.path_lookup import PathLookup
2224

2325
logger = logging.getLogger(__name__)
2426

@@ -161,17 +163,18 @@ def __init__(self, node: NodeNG, path: str, is_append: bool):
161163
self._path = path
162164
self._is_append = is_append
163165

164-
@property
165-
def node(self):
166-
return self._node
167-
168166
@property
169167
def path(self):
170168
return self._path
171169

172-
@property
173-
def is_append(self):
174-
return self._is_append
170+
def apply_to(self, path_lookup: PathLookup):
171+
path = Path(self._path)
172+
if not path.is_absolute():
173+
path = path_lookup.cwd / path
174+
if self._is_append:
175+
path_lookup.append_path(path)
176+
return
177+
path_lookup.prepend_path(path)
175178

176179

177180
class AbsolutePath(SysPathChange):

src/databricks/labs/ucx/source_code/linters/python_ast.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,24 +210,33 @@ def _get_attribute_value(cls, node: Attribute):
210210
logger.debug(f"Missing handler for {name}")
211211
return None
212212

213-
def append_statements(self, tree: Tree) -> Tree:
213+
def append_tree(self, tree: Tree) -> Tree:
214214
if not isinstance(tree.node, Module):
215-
raise NotImplementedError(f"Can't append statements from {type(tree.node).__name__}")
215+
raise NotImplementedError(f"Can't append tree from {type(tree.node).__name__}")
216216
tree_module: Module = cast(Module, tree.node)
217+
self.append_nodes(tree_module.body)
218+
self.append_globals(tree_module.globals)
219+
# the following may seem strange but it's actually ok to use the original module as tree root
220+
return tree
221+
222+
def append_globals(self, globs: dict):
217223
if not isinstance(self.node, Module):
218-
raise NotImplementedError(f"Can't append statements to {type(self.node).__name__}")
224+
raise NotImplementedError(f"Can't append globals to {type(self.node).__name__}")
219225
self_module: Module = cast(Module, self.node)
220-
for stmt in tree_module.body:
221-
stmt.parent = self_module
222-
self_module.body.append(stmt)
223-
for name, value in tree_module.globals.items():
226+
for name, value in globs.items():
224227
statements: list[Expr] = self_module.globals.get(name, None)
225228
if statements is None:
226229
self_module.globals[name] = list(value) # clone the source list to avoid side-effects
227230
continue
228231
statements.extend(value)
229-
# the following may seem strange but it's actually ok to use the original module as tree root
230-
return tree
232+
233+
def append_nodes(self, nodes: list[NodeNG]):
234+
if not isinstance(self.node, Module):
235+
raise NotImplementedError(f"Can't append statements to {type(self.node).__name__}")
236+
self_module: Module = cast(Module, self.node)
237+
for node in nodes:
238+
node.parent = self_module
239+
self_module.body.append(node)
231240

232241
def is_from_module(self, module_name: str):
233242
# if this is the call's root node, check it against the required module

src/databricks/labs/ucx/source_code/notebooks/cells.py

Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Callable
99
from enum import Enum
1010
from pathlib import Path
11+
from typing import TypeVar
1112

1213
from astroid import Call, Const, ImportFrom, Name, NodeNG # type: ignore
1314
from astroid.exceptions import AstroidSyntaxError # type: ignore
@@ -229,7 +230,8 @@ def is_runnable(self) -> bool:
229230
return True # TODO
230231

231232
def build_dependency_graph(self, graph: DependencyGraph) -> list[DependencyProblem]:
232-
return PipMagic(self.original_code).build_dependency_graph(graph)
233+
node = MagicNode(0, 1, None, end_lineno=0, end_col_offset=len(self.original_code))
234+
return PipCommand(node, self.original_code).build_dependency_graph(graph)
233235

234236

235237
class CellLanguage(Enum):
@@ -414,10 +416,10 @@ def build_graph_from_python_source(self, python_code: str) -> list[DependencyPro
414416
import_problems: list[DependencyProblem]
415417
import_sources, import_problems = ImportSource.extract_from_tree(tree, DependencyProblem.from_node)
416418
problems.extend(import_problems)
417-
magic_commands, command_problems = MagicCommand.extract_from_tree(tree, DependencyProblem.from_node)
419+
magic_commands, command_problems = MagicLine.extract_from_tree(tree, DependencyProblem.from_node)
418420
problems.extend(command_problems)
419421
nodes = syspath_changes + run_calls + import_sources + magic_commands
420-
# need to execute things in intertwined sequence so concat and sort
422+
# need to execute things in intertwined sequence so concat and sort them
421423
for base_node in sorted(nodes, key=lambda node: (node.node.lineno, node.node.col_offset)):
422424
for problem in self._process_node(base_node):
423425
# Astroid line numbers are 1-based.
@@ -437,7 +439,7 @@ def _process_node(self, base_node: NodeBase):
437439
yield from self._register_notebook(base_node)
438440
elif isinstance(base_node, ImportSource):
439441
yield from self._register_import(base_node)
440-
elif isinstance(base_node, MagicCommand):
442+
elif isinstance(base_node, MagicLine):
441443
yield from base_node.build_dependency_graph(self._context.parent)
442444
else:
443445
logger.warning(f"Can't process {NodeBase.__name__} of type {type(base_node).__name__}")
@@ -466,23 +468,20 @@ def _mutate_path_lookup(self, change: SysPathChange):
466468
f"Can't update sys.path from {change.node.as_string()} because the expression cannot be computed",
467469
)
468470
return
469-
path = Path(change.path)
470-
if not path.is_absolute():
471-
path = self._context.path_lookup.cwd / path
472-
if change.is_append:
473-
self._context.path_lookup.append_path(path)
474-
return
475-
self._context.path_lookup.prepend_path(path)
471+
change.apply_to(self._context.path_lookup)
472+
476473

474+
T = TypeVar("T")
477475

478-
class MagicCommand(NodeBase):
476+
477+
class MagicLine(NodeBase):
479478

480479
@classmethod
481480
def extract_from_tree(
482-
cls, tree: Tree, problem_factory: Callable[[str, str, NodeNG], DependencyProblem]
483-
) -> tuple[list[MagicCommand], list[DependencyProblem]]:
484-
problems: list[DependencyProblem] = []
485-
commands: list[MagicCommand] = []
481+
cls, tree: Tree, problem_factory: Callable[[str, str, NodeNG], T]
482+
) -> tuple[list[MagicLine], list[T]]:
483+
problems: list[T] = []
484+
commands: list[MagicLine] = []
486485
try:
487486
nodes = tree.locate(Call, [("magic_command", Name)])
488487
for command in cls._make_commands_for_magic_command_call_nodes(nodes):
@@ -498,36 +497,82 @@ def _make_commands_for_magic_command_call_nodes(cls, nodes: list[Call]):
498497
for node in nodes:
499498
arg = node.args[0]
500499
if isinstance(arg, Const):
501-
yield MagicCommand(node, arg.value)
500+
yield MagicLine(node, arg.value)
502501

503502
def __init__(self, node: NodeNG, command: bytes):
504503
super().__init__(node)
505504
self._command = command.decode()
506505

507-
def build_dependency_graph(self, graph: DependencyGraph) -> list[DependencyProblem]:
506+
def as_magic(self) -> MagicCommand | None:
508507
if self._command.startswith("%pip") or self._command.startswith("!pip"):
509-
cmd = PipMagic(self._command)
510-
return cmd.build_dependency_graph(graph)
508+
return PipCommand(self.node, self._command)
509+
if self._command.startswith("%run"):
510+
return RunCommand(self.node, self._command)
511+
return None
512+
513+
def build_dependency_graph(self, graph: DependencyGraph) -> list[DependencyProblem]:
514+
magic = self.as_magic()
515+
if magic is not None:
516+
return magic.build_dependency_graph(graph)
511517
problem = DependencyProblem.from_node(
512518
code='unsupported-magic-line', message=f"magic line '{self._command}' is not supported yet", node=self.node
513519
)
514520
return [problem]
515521

516522

517-
class PipMagic:
523+
class MagicNode(NodeNG):
524+
pass
518525

519-
def __init__(self, code: str):
526+
527+
class MagicCommand(ABC):
528+
529+
def __init__(self, node: NodeNG, code: str):
530+
self._node = node
520531
self._code = code
521532

533+
@abstractmethod
534+
def build_dependency_graph(self, graph: DependencyGraph) -> list[DependencyProblem]: ...
535+
536+
537+
class RunCommand(MagicCommand):
538+
539+
def build_dependency_graph(self, graph: DependencyGraph) -> list[DependencyProblem]:
540+
path = self.notebook_path
541+
if path is not None:
542+
problems = graph.register_notebook(path)
543+
return [problem.from_node(problem.code, problem.message, self._node) for problem in problems]
544+
problem = DependencyProblem.from_node('invalid-run-cell', "Missing notebook path in %run command", self._node)
545+
return [problem]
546+
547+
@property
548+
def notebook_path(self) -> Path | None:
549+
start = self._code.find(' ')
550+
if start < 0:
551+
return None
552+
path = self._code[start + 1 :].strip().strip('"').strip("'")
553+
return Path(path)
554+
555+
556+
class PipCommand(MagicCommand):
557+
522558
def build_dependency_graph(self, graph: DependencyGraph) -> list[DependencyProblem]:
523559
argv = self._split(self._code)
524560
if len(argv) == 1:
525-
return [DependencyProblem("library-install-failed", "Missing command after 'pip'")]
561+
return [DependencyProblem.from_node("library-install-failed", "Missing command after 'pip'", self._node)]
526562
if argv[1] != "install":
527-
return [DependencyProblem("library-install-failed", f"Unsupported 'pip' command: {argv[1]}")]
563+
return [
564+
DependencyProblem.from_node(
565+
"library-install-failed", f"Unsupported 'pip' command: {argv[1]}", self._node
566+
)
567+
]
528568
if len(argv) == 2:
529-
return [DependencyProblem("library-install-failed", "Missing arguments after 'pip install'")]
530-
return graph.register_library(*argv[2:]) # Skipping %pip install
569+
return [
570+
DependencyProblem.from_node(
571+
"library-install-failed", "Missing arguments after 'pip install'", self._node
572+
)
573+
]
574+
problems = graph.register_library(*argv[2:]) # Skipping %pip install
575+
return [problem.from_node(problem.code, problem.message, self._node) for problem in problems]
531576

532577
# Cache re-used regex (and ensure issues are raised during class init instead of upon first use).
533578
_splitter = re.compile(r"(?<!\\)\n")

0 commit comments

Comments
 (0)