Skip to content

Commit dc6d9c7

Browse files
authored
Fixed AttributeError: 'UsedTable' has no attribute 'table' by adding more type checks (#2895)
Fix #2887
1 parent f8c371b commit dc6d9c7

File tree

4 files changed

+57
-15
lines changed

4 files changed

+57
-15
lines changed

src/databricks/labs/ucx/source_code/base.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,16 @@ class TableInfoNode:
270270
class TablePyCollector(TableCollector, ABC):
271271

272272
def collect_tables(self, source_code: str):
273-
tree = Tree.normalize_and_parse(source_code)
274-
for table_node in self.collect_tables_from_tree(tree):
275-
yield table_node.table
273+
try:
274+
tree = Tree.normalize_and_parse(source_code)
275+
for table_node in self.collect_tables_from_tree(tree):
276+
# see https://github.com/databrickslabs/ucx/issues/2887
277+
if isinstance(table_node, UsedTable):
278+
yield table_node
279+
else:
280+
yield table_node.table
281+
except AstroidSyntaxError as e:
282+
logger.warning('syntax-error', exc_info=e)
276283

277284
@abstractmethod
278285
def collect_tables_from_tree(self, tree: Tree) -> Iterable[TableInfoNode]: ...
@@ -451,7 +458,12 @@ def collect_tables(self, source_code: str) -> Iterable[UsedTable]:
451458
try:
452459
tree = self._parse_and_append(source_code)
453460
for table_node in self.collect_tables_from_tree(tree):
454-
yield table_node.table
461+
# there's a bug in the code that causes this to be necessary
462+
# see https://github.com/databrickslabs/ucx/issues/2887
463+
if isinstance(table_node, UsedTable):
464+
yield table_node
465+
else:
466+
yield table_node.table
455467
except AstroidSyntaxError as e:
456468
logger.warning('syntax-error', exc_info=e)
457469

src/databricks/labs/ucx/source_code/jobs.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def lint_job(self, job_id: int) -> tuple[list[JobProblem], list[DirectFsAccess],
412412
def _lint_job(self, job: jobs.Job) -> tuple[list[JobProblem], list[DirectFsAccess], list[UsedTable]]:
413413
problems: list[JobProblem] = []
414414
dfsas: list[DirectFsAccess] = []
415-
table_infos: list[UsedTable] = []
415+
used_tables: list[UsedTable] = []
416416

417417
assert job.job_id is not None
418418
assert job.settings is not None
@@ -447,13 +447,14 @@ def _lint_job(self, job: jobs.Job) -> tuple[list[JobProblem], list[DirectFsAcces
447447
assessment_start = datetime.now(timezone.utc)
448448
task_tables = self._collect_task_tables(job, task, graph, session_state)
449449
assessment_end = datetime.now(timezone.utc)
450-
for table_info in task_tables:
451-
table_info = table_info.replace_assessment_infos(
452-
assessment_start=assessment_start, assessment_end=assessment_end
450+
for used_table in task_tables:
451+
used_table = used_table.replace_assessment_infos(
452+
assessment_start=assessment_start,
453+
assessment_end=assessment_end,
453454
)
454-
table_infos.append(table_info)
455+
used_tables.append(used_table)
455456

456-
return problems, dfsas, table_infos
457+
return problems, dfsas, used_tables
457458

458459
def _build_task_dependency_graph(
459460
self, task: jobs.Task, job: jobs.Job
@@ -502,17 +503,21 @@ def _collect_task_dfsas(
502503
yield dataclasses.replace(dfsa, source_lineage=atoms + dfsa.source_lineage)
503504

504505
def _collect_task_tables(
505-
self, job: jobs.Job, task: jobs.Task, graph: DependencyGraph, session_state: CurrentSessionState
506+
self,
507+
job: jobs.Job,
508+
task: jobs.Task,
509+
graph: DependencyGraph,
510+
session_state: CurrentSessionState,
506511
) -> Iterable[UsedTable]:
507512
# need to add lineage for job/task because walker doesn't register them
508513
job_id = str(job.job_id)
509514
job_name = job.settings.name if job.settings and job.settings.name else "<anonymous>"
510-
for dfsa in TablesCollectorWalker(graph, set(), self._path_lookup, session_state, self._migration_index):
515+
for used_table in TablesCollectorWalker(graph, set(), self._path_lookup, session_state, self._migration_index):
511516
atoms = [
512517
LineageAtom(object_type="WORKFLOW", object_id=job_id, other={"name": job_name}),
513518
LineageAtom(object_type="TASK", object_id=f"{job_id}/{task.task_key}"),
514519
]
515-
yield dataclasses.replace(dfsa, source_lineage=atoms + dfsa.source_lineage)
520+
yield dataclasses.replace(used_table, source_lineage=atoms + used_table.source_lineage)
516521

517522

518523
class LintingWalker(DependencyGraphWalker[LocatedAdvice]):

src/databricks/labs/ucx/source_code/linters/pyspark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def collect_tables_from_tree(self, tree: Tree) -> Iterable[TableInfoNode]:
395395
continue
396396
assert isinstance(node, Call)
397397
for used_table in matcher.collect_tables(self._from_table, self._index, self._session_state, node):
398-
yield TableInfoNode(used_table, node)
398+
yield TableInfoNode(used_table, node) # B
399399

400400

401401
class _SparkSqlAnalyzer:
@@ -475,4 +475,4 @@ def collect_tables_from_tree(self, tree: Tree) -> Iterable[TableInfoNode]:
475475
if not value.is_inferred():
476476
continue # TODO error handling strategy
477477
for table in self._sql_collector.collect_tables(value.as_string()):
478-
yield TableInfoNode(table, call_node)
478+
yield TableInfoNode(table, call_node) # A
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
from databricks.sdk.service.workspace import Language
3+
4+
from databricks.labs.ucx.source_code.linters.context import LinterContext
5+
6+
7+
@pytest.mark.parametrize(
    'code,expected',
    [
        ('spark.table("a.b").count()', {'r:a.b'}),
        ('spark.getTable("a.b")', {'r:a.b'}),
        ('spark.cacheTable("a.b")', {'r:a.b'}),
        ('spark.range(10).saveAsTable("a.b")', {'r:a.b'}),  # TODO: bug: has to be w:a.b
        ('spark.sql("SELECT * FROM b.c LEFT JOIN c.d USING (e)")', {'r:b.c', 'r:c.d'}),
        ('spark.sql("SELECT * FROM delta.`/foo/bar`")', set()),
    ],
)
def test_collector_walker_from_python(code, expected, migration_index) -> None:
    """Tables collected from a Python snippet match the expected r/w-prefixed names."""
    context = LinterContext(migration_index)
    collector = context.tables_collector(Language.PYTHON)
    used = {
        f"{'r' if table.is_read else 'w'}:{table.schema_name}.{table.table_name}"
        for table in collector.collect_tables(code)
    }
    assert used == expected

0 commit comments

Comments (0)