
Commit e60486f

Simplify SqlBackend and table creation logic (#203)
Fixes #202
1 parent 34c7a90 commit e60486f

File tree

9 files changed: +253 -178 lines changed


src/databricks/labs/ucx/framework/crawlers.py

Lines changed: 84 additions & 112 deletions
@@ -3,6 +3,7 @@
 import os
 from abc import ABC, abstractmethod
 from collections.abc import Iterator
+from typing import ClassVar
 
 from databricks.sdk import WorkspaceClient
 
@@ -20,6 +21,26 @@ def execute(self, sql):
     def fetch(self, sql) -> Iterator[any]:
         raise NotImplementedError
 
+    @abstractmethod
+    def save_table(self, full_name: str, rows: list[any], mode: str = "append"):
+        raise NotImplementedError
+
+    _builtin_type_mapping: ClassVar[dict[type, str]] = {str: "STRING", int: "INT", bool: "BOOLEAN", float: "FLOAT"}
+
+    @classmethod
+    def _schema_for(cls, klass):
+        fields = []
+        for f in dataclasses.fields(klass):
+            if f.type not in cls._builtin_type_mapping:
+                msg = f"Cannot auto-convert {f.type}"
+                raise SyntaxError(msg)
+            not_null = " NOT NULL"
+            if f.default is None:
+                not_null = ""
+            spark_type = cls._builtin_type_mapping[f.type]
+            fields.append(f"{f.name} {spark_type}{not_null}")
+        return ", ".join(fields)
+
 
 class StatementExecutionBackend(SqlBackend):
     def __init__(self, ws: WorkspaceClient, warehouse_id):
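For illustration only (not part of this commit): given a small dataclass such as the hypothetical Grantee below, SqlBackend._schema_for turns its fields into a Delta-compatible column list. Fields without a default become NOT NULL; a None default makes the column nullable.

from dataclasses import dataclass

@dataclass
class Grantee:  # hypothetical example type, not from the repo
    principal: str          # no default   -> principal STRING NOT NULL
    privilege: str          # no default   -> privilege STRING NOT NULL
    inherited: bool = None  # None default -> inherited BOOLEAN (nullable)

# SqlBackend._schema_for(Grantee) returns:
#   "principal STRING NOT NULL, privilege STRING NOT NULL, inherited BOOLEAN"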
@@ -34,6 +55,40 @@ def fetch(self, sql) -> Iterator[any]:
         logger.debug(f"[api][fetch] {sql}")
         return self._sql.execute_fetch_all(self._warehouse_id, sql)
 
+    def save_table(self, full_name: str, rows: list[any], mode="append"):
+        if mode == "overwrite":
+            msg = "Overwrite mode is not yet supported"
+            raise NotImplementedError(msg)
+
+        if len(rows) == 0:
+            return
+
+        klass = rows[0].__class__
+        ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._schema_for(klass)}) USING DELTA"
+        self.execute(ddl)
+
+        fields = dataclasses.fields(klass)
+        field_names = [f.name for f in fields]
+        vals = "), (".join(self._row_to_sql(r, fields) for r in rows)
+        sql = f'INSERT INTO {full_name} ({", ".join(field_names)}) VALUES ({vals})'
+        self.execute(sql)
+
+    @staticmethod
+    def _row_to_sql(row, fields):
+        data = []
+        for f in fields:
+            value = getattr(row, f.name)
+            if value is None:
+                data.append("NULL")
+            elif f.type == bool:
+                data.append("TRUE" if value else "FALSE")
+            elif f.type == str:
+                data.append(f"'{value}'")
+            else:
+                msg = f"unknown type: {f.type}"
+                raise ValueError(msg)
+        return ", ".join(data)
+
 
 class RuntimeBackend(SqlBackend):
     def __init__(self):
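Continuing the illustration (the table name and rows are made up, and ws/warehouse_id are assumed to be configured), appending two Grantee rows through StatementExecutionBackend.save_table issues statements along these lines:

backend = StatementExecutionBackend(ws, warehouse_id)
backend.save_table(
    "hive_metastore.ucx.grants",
    [Grantee("alice", "SELECT", False), Grantee("bob", "MODIFY")],
)
# executes, in order:
#   CREATE TABLE IF NOT EXISTS hive_metastore.ucx.grants
#       (principal STRING NOT NULL, privilege STRING NOT NULL, inherited BOOLEAN) USING DELTA
#   INSERT INTO hive_metastore.ucx.grants (principal, privilege, inherited)
#       VALUES ('alice', 'SELECT', FALSE), ('bob', 'MODIFY', NULL)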
@@ -42,6 +97,7 @@ def __init__(self):
         if "DATABRICKS_RUNTIME_VERSION" not in os.environ:
             msg = "Not in the Databricks Runtime"
             raise RuntimeError(msg)
+
         self._spark = SparkSession.builder.getOrCreate()
 
     def execute(self, sql):
@@ -52,6 +108,13 @@ def fetch(self, sql) -> Iterator[any]:
         logger.debug(f"[spark][fetch] {sql}")
         return self._spark.sql(sql).collect()
 
+    def save_table(self, full_name: str, rows: list[any], mode: str = "append"):
+        if len(rows) == 0:
+            return
+        # pyspark deals well with lists of dataclass instances, as long as schema is provided
+        df = self._spark.createDataFrame(rows, self._schema_for(rows[0]))
+        df.write.saveAsTable(full_name, mode=mode)
+
 
 class CrawlerBase:
     def __init__(self, backend: SqlBackend, catalog: str, schema: str, table: str):
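The Spark-side counterpart, sketched with the same hypothetical rows; it only runs inside the Databricks Runtime, where pyspark builds a DataFrame from the dataclass instances plus the schema string produced by _schema_for:

backend = RuntimeBackend()  # raises RuntimeError outside the Databricks Runtime
backend.save_table(
    "hive_metastore.ucx.grants",
    [Grantee("alice", "SELECT", False)],
    mode="append",
)
# roughly: spark.createDataFrame(rows, schema).write.saveAsTable(full_name, mode="append")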
@@ -116,129 +179,38 @@ def _try_valid(cls, name: str):
             return None
         return cls._valid(name)
 
-    def _snapshot(self, klass, fetcher, loader) -> list[any]:
+    def _snapshot(self, fetcher, loader) -> list[any]:
         """
-        Tries to load dataset of records with the type `klass` with `fetcher` function,
-        otherwise automatically creates a table with the schema defined in `klass` and
-        executes `loader` function to populate the dataset.
+        Tries to load dataset of records with `fetcher` function, otherwise automatically creates
+        a table with the schema defined in the class of the first row and executes `loader` function
+        to populate the dataset.
 
         Args:
-            klass: The class representing the data structure.
             fetcher: A function to fetch existing data.
             loader: A function to load new data.
 
-        Behavior:
-        - Initiates an infinite loop to attempt fetching existing data using the provided fetcher function.
-        - If the fetcher function encounters a runtime error with the message "TABLE_OR_VIEW_NOT_FOUND",
-          it indicates that the data does not exist in the table.
-        - In this case, the method logs that the data is not found and triggers the loader function to load new data.
-        - The new data loaded by the loader function is then appended to the existing table using the `_append_records`
-          method.
-
-        Note:
-        - The method assumes that the provided fetcher and loader functions operate on the same data structure.
-        - The fetcher function should return an iterator of data records.
-        - The loader function should return an iterator of new data records to be added to the table.
-
         Exceptions:
         - If a runtime error occurs during fetching (other than "TABLE_OR_VIEW_NOT_FOUND"), the original error is
           re-raised.
 
         Returns:
            list[any]: A list of data records, either fetched or loaded.
        """
-        loaded = False
-        trigger_load = ValueError("trigger records load")
-        while True:
-            try:
-                logger.debug(f"[{self._full_name}] fetching {self._table} inventory")
-                cached_results = list(fetcher())
-                if len(cached_results) == 0 and loaded:
-                    return cached_results
-                if len(cached_results) == 0 and not loaded:
-                    raise trigger_load
+        logger.debug(f"[{self._full_name}] fetching {self._table} inventory")
+        try:
+            cached_results = list(fetcher())
+            if len(cached_results) > 0:
                 return cached_results
-            except Exception as e:
-                if not (e == trigger_load or "TABLE_OR_VIEW_NOT_FOUND" in str(e)):
-                    raise e
-                logger.debug(f"[{self._full_name}] crawling new batch for {self._table}")
-                loaded_records = list(loader())
-                if len(loaded_records) > 0:
-                    logger.debug(f"[{self._full_name}] found {len(loaded_records)} new records for {self._table}")
-                    self._append_records(klass, loaded_records)
-                loaded = True
-
-    @staticmethod
-    def _row_to_sql(row, fields):
-        data = []
-        for f in fields:
-            value = getattr(row, f.name)
-            if value is None:
-                data.append("NULL")
-            elif f.type == bool:
-                data.append("TRUE" if value else "FALSE")
-            elif f.type == str:
-                data.append(f"'{value}'")
-            else:
-                msg = f"unknown type: {f.type}"
-                raise ValueError(msg)
-        return ", ".join(data)
-
-    @staticmethod
-    def _field_type(f):
-        if f.type == bool:
-            return "BOOLEAN"
-        elif f.type == str:
-            return "STRING"
-        else:
-            msg = f"unknown type: {f.type}"
-            raise ValueError(msg)
-
-    def _append_records(self, klass, records: Iterator[any]):
-        """
-        Appends records to the table or creates the table if it does not exist.
-
-        Args:
-            klass: The class representing the data structure.
-            records (Iterator[any]): An iterator of records to be appended.
-
-        Behavior:
-        - Retrieves the fields of the provided class representing the data.
-        - Generates a comma-separated list of field names from the fields.
-        - Converts each record into a formatted SQL representation using the `_row_to_sql` method.
-        - Constructs an SQL INSERT statement with the formatted field names and values.
-        - Attempts to execute the INSERT statement using the `_exec` function.
-        - If the table does not exist (TABLE_OR_VIEW_NOT_FOUND), it creates the table using a CREATE TABLE statement.
-
-        Note:
-        - The method assumes that the target table exists in the database.
-        - If the table does not exist, it will be created with the schema inferred from the class fields.
-        - If the table already exists, the provided records will be appended to it.
-
-        Exceptions:
-        - If a runtime error occurs during execution, it checks if the error message contains "TABLE_OR_VIEW_NOT_FOUND".
-        - If the table does not exist, a new table will be created using the schema inferred from the class fields.
-        - If the error is different, the original error is re-raised.
-        """
-        fields = dataclasses.fields(klass)
-        field_names = [f.name for f in fields]
-        vals = "), (".join(self._row_to_sql(r, fields) for r in records)
-        sql = f'INSERT INTO {self._full_name} ({", ".join(field_names)}) VALUES ({vals})'
-        while True:
-            try:
-                logger.debug(f"[{self._full_name}] appending records")
-                self._exec(sql)
-                return
-            except Exception as e:
-                if "TABLE_OR_VIEW_NOT_FOUND" not in str(e):
-                    raise e
-                logger.debug(f"[{self._full_name}] not found. creating")
-                schema = ", ".join(f"{f.name} {self._field_type(f)}" for f in fields)
-                try:
-                    self._exec(f"CREATE TABLE {self._full_name} ({schema}) USING DELTA")
-                except Exception as e:
-                    schema_not_found = "SCHEMA_NOT_FOUND" in str(e)
-                    if not schema_not_found:
-                        raise e
-                    logger.debug(f"[{self._catalog}.{self._schema}] not found. creating")
-                    self._exec(f"CREATE SCHEMA {self._catalog}.{self._schema}")
+        except Exception as err:
+            if "TABLE_OR_VIEW_NOT_FOUND" not in str(err):
+                raise err
+        logger.debug(f"[{self._full_name}] crawling new batch for {self._table}")
+        loaded_records = list(loader())
+        self._append_records(loaded_records)
+        return loaded_records
+
+    def _append_records(self, items):
+        if len(items) == 0:
+            return
+        logger.debug(f"[{self._full_name}] found {len(items)} new records for {self._table}")
+        self._backend.save_table(self._full_name, items, mode="append")
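To make the simplified fetch-or-crawl flow concrete, here is a minimal sketch of a crawler built on the new CrawlerBase API; ClusterCrawler and ClusterInfo are invented for this example and are not part of the commit:

from dataclasses import dataclass

from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend


@dataclass
class ClusterInfo:  # illustrative record type
    cluster_id: str
    policy_id: str = None


class ClusterCrawler(CrawlerBase):  # hypothetical crawler
    def __init__(self, backend: SqlBackend, catalog: str, schema: str):
        super().__init__(backend, catalog, schema, "clusters")

    def snapshot(self) -> list[ClusterInfo]:
        # first call: the fetcher hits TABLE_OR_VIEW_NOT_FOUND, so the loader crawls once
        # and its records are persisted via backend.save_table; later calls return the cache
        return self._snapshot(self._try_load, self._crawl)

    def _try_load(self):
        for row in self._backend.fetch(f"SELECT * FROM {self._full_name}"):
            yield ClusterInfo(*row)

    def _crawl(self) -> list[ClusterInfo]:
        return [ClusterInfo(cluster_id="0123-456789-example")]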

src/databricks/labs/ucx/framework/tasks.py

Lines changed: 4 additions & 0 deletions
@@ -47,6 +47,10 @@ def wrapper(*args, **kwargs):
                 continue
             deps.append(fn.__name__)
 
+        if not func.__doc__:
+            msg = f"Task {func.__name__} must have documentation"
+            raise SyntaxError(msg)
+
         _TASKS[func.__name__] = Task(
             workflow=workflow, name=func.__name__, doc=func.__doc__, fn=func, depends_on=deps, job_cluster=job_cluster
         )
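Illustration (not from the repo): with this guard in place, registering an undocumented task fails at import time instead of producing a job task with an empty description:

@task("assessment")
def crawl_something(cfg: MigrationConfig):
    pass  # no docstring

# -> SyntaxError: Task crawl_something must have documentation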

src/databricks/labs/ucx/hive_metastore/grants.py

Lines changed: 1 addition & 3 deletions
@@ -120,9 +120,7 @@ def __init__(self, tc: TablesCrawler):
         self._tc = tc
 
     def snapshot(self, catalog: str, database: str) -> list[Grant]:
-        return self._snapshot(
-            Grant, partial(self._try_load, catalog, database), partial(self._crawl, catalog, database)
-        )
+        return self._snapshot(partial(self._try_load, catalog, database), partial(self._crawl, catalog, database))
 
     def _try_load(self, catalog: str, database: str):
         for row in self._fetch(

src/databricks/labs/ucx/hive_metastore/tables.py

Lines changed: 1 addition & 3 deletions
@@ -95,9 +95,7 @@ def snapshot(self, catalog: str, database: str) -> list[Table]:
         Returns:
             list[Table]: A list of Table objects representing the snapshot of tables.
         """
-        return self._snapshot(
-            Table, partial(self._try_load, catalog, database), partial(self._crawl, catalog, database)
-        )
+        return self._snapshot(partial(self._try_load, catalog, database), partial(self._crawl, catalog, database))
 
     def _try_load(self, catalog: str, database: str):
         """Tries to load table information from the database or throws TABLE_OR_VIEW_NOT_FOUND error"""

src/databricks/labs/ucx/runtime.py

Lines changed: 9 additions & 1 deletion
@@ -5,6 +5,7 @@
 from databricks.sdk import WorkspaceClient
 
 from databricks.labs.ucx.config import MigrationConfig
+from databricks.labs.ucx.framework.crawlers import RuntimeBackend
 from databricks.labs.ucx.framework.tasks import task, trigger
 from databricks.labs.ucx.hive_metastore import TaclToolkit
 from databricks.labs.ucx.workspace_access import GroupMigrationToolkit
@@ -13,6 +14,13 @@
 
 
 @task("assessment")
+def setup_schema(cfg: MigrationConfig):
+    """Creates a database for UCX migration intermediate state"""
+    backend = RuntimeBackend()
+    backend.execute(f"CREATE SCHEMA IF NOT EXISTS hive_metastore.{cfg.inventory_database}")
+
+
+@task("assessment", depends_on=[setup_schema])
 def crawl_tables(cfg: MigrationConfig):
     """During this operation, a systematic scan is conducted, encompassing every table within the Hive Metastore.
     This scan extracts essential details associated with each table, including its unique identifier or name, table
@@ -48,7 +56,7 @@ def crawl_grants(cfg: MigrationConfig):
     tacls.grants_snapshot()
 
 
-@task("assessment")
+@task("assessment", depends_on=[setup_schema])
 def inventorize_permissions(cfg: MigrationConfig):
     """As we embark on the complex migration journey from Hive Metastore to the Databricks Unity Catalog, a pivotal
     aspect of this transition is the comprehensive examination and preservation of permissions associated with a myriad
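In effect (a sketch, assuming cfg.inventory_database is "ucx"), the new bootstrap task reduces to a single idempotent DDL statement, and every task declaring depends_on=[setup_schema] is scheduled after it:

backend = RuntimeBackend()
# with cfg.inventory_database == "ucx", setup_schema boils down to:
backend.execute("CREATE SCHEMA IF NOT EXISTS hive_metastore.ucx")
# so save_table can CREATE TABLE IF NOT EXISTS without creating the schema lazily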

src/databricks/labs/ucx/workspace_access/manager.py

Lines changed: 1 addition & 3 deletions
@@ -73,9 +73,7 @@ def cleanup(self):
         logger.info("Inventory table cleanup complete")
 
     def _save(self, items: list[Permissions]):
-        # TODO: update instead of append
-        logger.info(f"Saving {len(items)} items to {self._full_name}")
-        self._append_records(Permissions, items)
+        self._append_records(items)  # TODO: update instead of append
         logger.info("Successfully saved the items to inventory table")
 
     def _load_all(self) -> list[Permissions]:

tests/unit/framework/mocks.py

Lines changed: 12 additions & 0 deletions
@@ -13,6 +13,7 @@ def __init__(self, *, fails_on_first: dict | None = None, rows: dict | None = No
         if not rows:
             rows = {}
         self._rows = rows
+        self._save_table = []
         self.queries = []
 
     def _sql(self, sql):
@@ -39,3 +40,14 @@ def fetch(self, sql) -> Iterator[any]:
         rows.extend(self._rows[pattern])
         logger.debug(f"Returning rows: {rows}")
         return iter(rows)
+
+    def save_table(self, full_name: str, rows: list[any], mode: str = "append"):
+        self._save_table.append((full_name, rows, mode))
+
+    def rows_written_for(self, full_name: str, mode: str) -> list[any]:
+        rows = []
+        for stub_full_name, stub_rows, stub_mode in self._save_table:
+            if not (stub_full_name == full_name and stub_mode == mode):
+                continue
+            rows += stub_rows
+        return rows
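A minimal sketch of how a unit test might exercise the new stub; the test name and the Thing dataclass are invented for illustration:

from dataclasses import dataclass

from databricks.labs.ucx.framework.crawlers import CrawlerBase
from tests.unit.framework.mocks import MockBackend


@dataclass
class Thing:  # minimal record type for the test
    key: str


def test_append_records_uses_save_table():
    backend = MockBackend()
    crawler = CrawlerBase(backend, "hive_metastore", "ucx", "things")
    crawler._append_records([Thing("x")])
    assert backend.rows_written_for("hive_metastore.ucx.things", "append") == [Thing("x")]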
