Commit 50d57e8
Improved table migrations logic (#1050)
## Changes

- Refactor table migration unit tests to load table mappings from JSON fixtures instead of inline structs. This requires pulling `assessment/__init__.py` one level up so its fixture-loading logic can be reused.
- Add `escape_sql_identifier` where it was missing.
- This prepares for the ACLs migration, which relies on similar logic.

### Tests

- [x] manually tested
- [x] verified on staging environment (screenshot attached)
1 parent 3a43ccf commit 50d57e8

File tree: 18 files changed, +331 −212 lines

src/databricks/labs/ucx/hive_metastore/grants.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -110,7 +110,7 @@ def inner(object_type, object_key):
 
         return inner
 
-    def uc_grant_sql(self):
+    def uc_grant_sql(self, object_type: str | None = None, object_key: str | None = None) -> str | None:
         """Get SQL translated SQL statement for granting similar permissions in UC.
 
         If there's no UC equivalent, returns None. This can also be the case for missing mapping.
@@ -120,13 +120,17 @@ def uc_grant_sql(self):
         # See: https://docs.databricks.com/sql/language-manual/sql-ref-privileges-hms.html
         # See: https://docs.databricks.com/data-governance/unity-catalog/manage-privileges/ownership.html
         # See: https://docs.databricks.com/data-governance/unity-catalog/manage-privileges/privileges.html
-        object_type, object_key = self.this_type_and_key()
+        if object_type is None:
+            object_type, object_key = self.this_type_and_key()
         hive_to_uc = {
             ("FUNCTION", "SELECT"): self._uc_action("EXECUTE"),
             ("TABLE", "SELECT"): self._uc_action("SELECT"),
             ("TABLE", "MODIFY"): self._uc_action("MODIFY"),
             ("TABLE", "READ_METADATA"): self._uc_action("BROWSE"),
             ("TABLE", "OWN"): self._set_owner_sql,
+            ("VIEW", "SELECT"): self._uc_action("SELECT"),
+            ("VIEW", "READ_METADATA"): self._uc_action("BROWSE"),
+            ("VIEW", "OWN"): self._set_owner_sql,
             ("DATABASE", "USAGE"): self._uc_action("USE SCHEMA"),
             ("DATABASE", "CREATE"): self._uc_action("CREATE TABLE"),
             ("DATABASE", "CREATE_NAMED_FUNCTION"): self._uc_action("CREATE FUNCTION"),
```

src/databricks/labs/ucx/hive_metastore/mapping.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -2,6 +2,7 @@
 import re
 from dataclasses import dataclass
 from functools import partial
+from typing import Any
 
 from databricks.labs.blueprint.installation import Installation
 from databricks.labs.blueprint.parallel import Threads
@@ -27,6 +28,18 @@ class Rule:
     src_table: str
     dst_table: str
 
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]):
+        """Deserializes the Rule from a dictionary."""
+        return cls(
+            workspace_name=data["workspace_name"],
+            catalog_name=data["catalog_name"],
+            src_schema=data["src_schema"],
+            dst_schema=data["dst_schema"],
+            src_table=data["src_table"],
+            dst_table=data["dst_table"],
+        )
+
     @classmethod
     def initial(cls, workspace_name: str, catalog_name: str, table: Table) -> "Rule":
         return cls(
@@ -52,6 +65,11 @@ class TableToMigrate:
     src: Table
     rule: Rule
 
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]):
+        """Deserializes the TableToMigrate from a dictionary."""
+        return cls(Table.from_dict(data["table"]), Rule.from_dict(data["rule"]))
+
 
 class TableMapping:
     UCX_SKIP_PROPERTY = "databricks.labs.ucx.skip"
```

src/databricks/labs/ucx/hive_metastore/table_migrate.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -19,6 +19,7 @@
 
 from databricks.labs.ucx.config import WorkspaceConfig
 from databricks.labs.ucx.framework.crawlers import CrawlerBase
+from databricks.labs.ucx.framework.utils import escape_sql_identifier
 from databricks.labs.ucx.hive_metastore import TablesCrawler
 from databricks.labs.ucx.hive_metastore.mapping import Rule, TableMapping
 from databricks.labs.ucx.hive_metastore.tables import MigrationCount, Table, What
@@ -485,7 +486,7 @@ def get_seen_tables(self) -> dict[str, str]:
         return seen_tables
 
     def is_upgraded(self, schema: str, table: str) -> bool:
-        result = self._backend.fetch(f"SHOW TBLPROPERTIES `{schema}`.`{table}`")
+        result = self._backend.fetch(f"SHOW TBLPROPERTIES {escape_sql_identifier(schema+'.'+table)}")
         for value in result:
             if value["key"] == "upgraded_to":
                 logger.info(f"{schema}.{table} is set as upgraded")
```

src/databricks/labs/ucx/hive_metastore/tables.py

Lines changed: 20 additions & 4 deletions
```diff
@@ -56,6 +56,20 @@ class Table:
 
     UPGRADED_FROM_WS_PARAM: typing.ClassVar[str] = "upgraded_from_workspace_id"
 
+    @classmethod
+    def from_dict(cls, data: dict[str, typing.Any]):
+        return cls(
+            catalog=data.get("catalog", "UNKNOWN"),
+            database=data.get("database", "UNKNOWN"),
+            name=data.get("name", "UNKNOWN"),
+            object_type=data.get("object_type", "UNKNOWN"),
+            table_format=data.get("table_format", "UNKNOWN"),
+            location=data.get("location", None),
+            view_text=data.get("view_text", None),
+            upgraded_to=data.get("upgraded_to", None),
+            storage_properties=data.get("storage_properties", None),
+        )
+
     @property
     def is_delta(self) -> bool:
         if self.table_format is None:
@@ -71,17 +85,17 @@ def kind(self) -> str:
         return "VIEW" if self.view_text is not None else "TABLE"
 
     def sql_alter_to(self, target_table_key):
-        return f"ALTER {self.kind} {self.key} SET TBLPROPERTIES ('upgraded_to' = '{target_table_key}');"
+        return f"ALTER {self.kind} {escape_sql_identifier(self.key)} SET TBLPROPERTIES ('upgraded_to' = '{target_table_key}');"
 
     def sql_alter_from(self, target_table_key, ws_id):
         return (
-            f"ALTER {self.kind} {target_table_key} SET TBLPROPERTIES "
+            f"ALTER {self.kind} {escape_sql_identifier(target_table_key)} SET TBLPROPERTIES "
             f"('upgraded_from' = '{self.key}'"
             f" , '{self.UPGRADED_FROM_WS_PARAM}' = '{ws_id}');"
        )
 
     def sql_unset_upgraded_to(self):
-        return f"ALTER {self.kind} {self.key} UNSET TBLPROPERTIES IF EXISTS('upgraded_to');"
+        return f"ALTER {self.kind} {escape_sql_identifier(self.key)} UNSET TBLPROPERTIES IF EXISTS('upgraded_to');"
 
     @property
     def is_dbfs_root(self) -> bool:
@@ -260,7 +274,9 @@ def _describe(self, catalog: str, database: str, table: str) -> Table | None:
                 upgraded_to=self._parse_table_props(describe.get("Table Properties", "").lower()).get(
                     "upgraded_to", None
                 ),
-                storage_properties=self._parse_table_props(describe.get("Storage Properties", "").lower()),  # type: ignore[arg-type]
+                storage_properties=self._parse_table_props(
+                    describe.get("Storage Properties", "").lower()
+                ),  # type: ignore[arg-type]
             )
         except Exception as e:  # pylint: disable=broad-exception-caught
             # TODO: https://github.com/databrickslabs/ucx/issues/406
```
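The effect of routing identifiers through `escape_sql_identifier` is easiest to see on a concrete table. Assuming the part-wise backtick escaping sketched earlier, `sql_alter_to` would now emit something like the following (table and target names are illustrative):

```python
from databricks.labs.ucx.hive_metastore.tables import Table

# Illustrative source table with a dash in the database name, which the
# old unescaped f-string would have rendered as invalid SQL.
src = Table(
    catalog="hive_metastore",
    database="db-1",
    name="old_table",
    object_type="MANAGED",
    table_format="DELTA",
)
print(src.sql_alter_to("main.db1.new_table"))
# Expected shape (assuming part-wise backtick escaping):
# ALTER TABLE `hive_metastore`.`db-1`.`old_table` SET TBLPROPERTIES ('upgraded_to' = 'main.db1.new_table');
```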

tests/unit/__init__.py

Lines changed: 110 additions & 0 deletions
```diff
@@ -1,4 +1,18 @@
+import base64
+import json
 import logging
+import pathlib
+from unittest.mock import create_autospec
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.errors import NotFound
+from databricks.sdk.service.compute import ClusterDetails, Policy
+from databricks.sdk.service.jobs import BaseJob, BaseRun
+from databricks.sdk.service.pipelines import GetPipelineResponse, PipelineStateInfo
+from databricks.sdk.service.sql import EndpointConfPair
+from databricks.sdk.service.workspace import ExportResponse, GetSecretResponse
+
+from databricks.labs.ucx.hive_metastore.mapping import TableMapping, TableToMigrate
 
 logging.getLogger("tests").setLevel("DEBUG")
 
@@ -13,3 +27,99 @@
         },
     },
 }
+
+__dir = pathlib.Path(__file__).parent
+
+
+def _base64(filename: str):
+    try:
+        with (__dir / filename).open("rb") as f:
+            return base64.b64encode(f.read())
+    except FileNotFoundError as err:
+        raise NotFound(filename) from err
+
+
+def _workspace_export(filename: str):
+    res = _base64(f'workspace/{filename}')
+    return ExportResponse(content=res.decode('utf8'))
+
+
+def _load_fixture(filename: str):
+    try:
+        with (__dir / filename).open("r") as f:
+            return json.load(f)
+    except FileNotFoundError as err:
+        raise NotFound(filename) from err
+
+
+_FOLDERS = {
+    BaseJob: 'assessment/jobs',
+    BaseRun: 'assessment/jobruns',
+    ClusterDetails: 'assessment/clusters',
+    PipelineStateInfo: 'assessment/pipelines',
+    Policy: 'assessment/policies',
+    TableToMigrate: 'hive_metastore/tables',
+}
+
+
+def _load_list(cls: type, filename: str, ids=None):
+    if not ids:  # TODO: remove
+        return [cls.from_dict(_) for _ in _load_fixture(filename)]  # type: ignore[attr-defined]
+    return _id_list(cls, ids)
+
+
+def _id_list(cls: type, ids=None):
+    if not ids:
+        return []
+    return [cls.from_dict(_load_fixture(f'{_FOLDERS[cls]}/{_}.json')) for _ in ids]  # type: ignore[attr-defined]
+
+
+def _cluster_policy(policy_id: str):
+    fixture = _load_fixture(f"assessment/policies/{policy_id}.json")
+    definition = json.dumps(fixture["definition"])
+    overrides = json.dumps(fixture["policy_family_definition_overrides"])
+    return Policy(description=definition, policy_family_definition_overrides=overrides)
+
+
+def _pipeline(pipeline_id: str):
+    fixture = _load_fixture(f"assessment/pipelines/{pipeline_id}.json")
+    return GetPipelineResponse.from_dict(fixture)
+
+
+def _secret_not_found(secret_scope, _):
+    msg = f"Secret Scope {secret_scope} does not exist!"
+    raise NotFound(msg)
+
+
+def workspace_client_mock(
+    cluster_ids: list[str] | None = None,
+    pipeline_ids: list[str] | None = None,
+    job_ids: list[str] | None = None,
+    jobruns_ids: list[str] | None = None,
+    policy_ids: list[str] | None = None,
+    warehouse_config="single-config.json",
+    secret_exists=True,
+):
+    ws = create_autospec(WorkspaceClient)
+    ws.clusters.list.return_value = _id_list(ClusterDetails, cluster_ids)
+    ws.cluster_policies.list.return_value = _id_list(Policy, policy_ids)
+    ws.cluster_policies.get = _cluster_policy
+    ws.pipelines.list_pipelines.return_value = _id_list(PipelineStateInfo, pipeline_ids)
+    ws.pipelines.get = _pipeline
+    ws.jobs.list.return_value = _id_list(BaseJob, job_ids)
+    ws.jobs.list_runs.return_value = _id_list(BaseRun, jobruns_ids)
+    ws.warehouses.get_workspace_warehouse_config().data_access_config = _load_list(
+        EndpointConfPair, f"assessment/warehouses/{warehouse_config}"
+    )
+    ws.workspace.export = _workspace_export
+    if secret_exists:
+        ws.secrets.get_secret.return_value = GetSecretResponse(key="username", value="SGVsbG8sIFdvcmxkIQ==")
+    else:
+        ws.secrets.get_secret = _secret_not_found
+    return ws
+
+
+def table_mapping_mock(tables: list[str] | None = None):
+    table_mapping = create_autospec(TableMapping)
+    table_mapping.get_tables_to_migrate.return_value = _id_list(TableToMigrate, tables)
+    return table_mapping
```
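With the helpers promoted to `tests/unit/__init__.py`, assessment tests and the new table-migration tests share one fixture loader: `_FOLDERS` maps `TableToMigrate` to `hive_metastore/tables`, so each fixture id resolves to a JSON file in that folder. A sketch of a consuming test (module path, fixture id, and asserted value are illustrative):

```python
# e.g. in tests/unit/hive_metastore/test_table_migrate.py
from .. import table_mapping_mock


def test_tables_to_migrate_loaded_from_json():
    # Resolves to tests/unit/hive_metastore/tables/managed_dbfs.json
    # via _FOLDERS; the fixture id is illustrative.
    table_mapping = table_mapping_mock(["managed_dbfs"])

    tables = table_mapping.get_tables_to_migrate()

    assert len(tables) == 1
    assert tables[0].rule.dst_schema == "db1_dst"  # value from the hypothetical fixture
```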

tests/unit/assessment/__init__.py

Lines changed: 0 additions & 101 deletions
```diff
@@ -1,101 +0,0 @@
-import base64
-import json
-import pathlib
-from unittest.mock import create_autospec
-
-from databricks.sdk import WorkspaceClient
-from databricks.sdk.errors import NotFound
-from databricks.sdk.service.compute import ClusterDetails, Policy
-from databricks.sdk.service.jobs import BaseJob, BaseRun
-from databricks.sdk.service.pipelines import GetPipelineResponse, PipelineStateInfo
-from databricks.sdk.service.sql import EndpointConfPair
-from databricks.sdk.service.workspace import ExportResponse, GetSecretResponse
-
-__dir = pathlib.Path(__file__).parent
-
-
-def _base64(filename: str):
-    try:
-        with (__dir / filename).open("rb") as f:
-            return base64.b64encode(f.read())
-    except FileNotFoundError as err:
-        raise NotFound(filename) from err
-
-
-def _workspace_export(filename: str):
-    res = _base64(f'workspace/{filename}')
-    return ExportResponse(content=res.decode('utf8'))
-
-
-def _load_fixture(filename: str):
-    try:
-        with (__dir / filename).open("r") as f:
-            return json.load(f)
-    except FileNotFoundError as err:
-        raise NotFound(filename) from err
-
-
-_FOLDERS = {
-    BaseJob: '../assessment/jobs',
-    BaseRun: '../assessment/jobruns',
-    ClusterDetails: '../assessment/clusters',
-    PipelineStateInfo: '../assessment/pipelines',
-    Policy: '../assessment/policies',
-}
-
-
-def _load_list(cls: type, filename: str, ids=None):
-    if not ids:  # TODO: remove
-        return [cls.from_dict(_) for _ in _load_fixture(filename)]  # type: ignore[attr-defined]
-    return _id_list(cls, ids)
-
-
-def _id_list(cls: type, ids=None):
-    if not ids:
-        return []
-    return [cls.from_dict(_load_fixture(f'{_FOLDERS[cls]}/{_}.json')) for _ in ids]  # type: ignore[attr-defined]
-
-
-def _cluster_policy(policy_id: str):
-    fixture = _load_fixture(f"policies/{policy_id}.json")
-    definition = json.dumps(fixture["definition"])
-    overrides = json.dumps(fixture["policy_family_definition_overrides"])
-    return Policy(description=definition, policy_family_definition_overrides=overrides)
-
-
-def _pipeline(pipeline_id: str):
-    fixture = _load_fixture(f"pipelines/{pipeline_id}.json")
-    return GetPipelineResponse.from_dict(fixture)
-
-
-def _secret_not_found(secret_scope, _):
-    msg = f"Secret Scope {secret_scope} does not exist!"
-    raise NotFound(msg)
-
-
-def workspace_client_mock(
-    cluster_ids: list[str] | None = None,
-    pipeline_ids: list[str] | None = None,
-    job_ids: list[str] | None = None,
-    jobruns_ids: list[str] | None = None,
-    policy_ids: list[str] | None = None,
-    warehouse_config="single-config.json",
-    secret_exists=True,
-):
-    ws = create_autospec(WorkspaceClient)
-    ws.clusters.list.return_value = _id_list(ClusterDetails, cluster_ids)
-    ws.cluster_policies.list.return_value = _id_list(Policy, policy_ids)
-    ws.cluster_policies.get = _cluster_policy
-    ws.pipelines.list_pipelines.return_value = _id_list(PipelineStateInfo, pipeline_ids)
-    ws.pipelines.get = _pipeline
-    ws.jobs.list.return_value = _id_list(BaseJob, job_ids)
-    ws.jobs.list_runs.return_value = _id_list(BaseRun, jobruns_ids)
-    ws.warehouses.get_workspace_warehouse_config().data_access_config = _load_list(
-        EndpointConfPair, f"../assessment/warehouses/{warehouse_config}"
-    )
-    ws.workspace.export = _workspace_export
-    if secret_exists:
-        ws.secrets.get_secret.return_value = GetSecretResponse(key="username", value="SGVsbG8sIFdvcmxkIQ==")
-    else:
-        ws.secrets.get_secret = _secret_not_found
-    return ws
```

tests/unit/assessment/test_azure.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -2,7 +2,7 @@
 
 from databricks.labs.ucx.assessment.azure import AzureServicePrincipalCrawler
 
-from . import workspace_client_mock
+from .. import workspace_client_mock
 
 
 def test_azure_service_principal_info_crawl():
```

tests/unit/assessment/test_clusters.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,7 +11,7 @@
 from databricks.labs.ucx.assessment.clusters import ClustersCrawler, PoliciesCrawler
 from databricks.labs.ucx.framework.crawlers import SqlBackend
 
-from . import workspace_client_mock
+from .. import workspace_client_mock
 
 
 def test_cluster_assessment():
```

tests/unit/assessment/test_jobs.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@
 
 from databricks.labs.ucx.assessment.jobs import JobsCrawler, SubmitRunsCrawler
 
-from . import workspace_client_mock
+from .. import workspace_client_mock
 
 
 def test_job_assessment():
```

tests/unit/assessment/test_pipelines.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@
 from databricks.labs.ucx.assessment.azure import AzureServicePrincipalCrawler
 from databricks.labs.ucx.assessment.pipelines import PipelinesCrawler
 
-from . import workspace_client_mock
+from .. import workspace_client_mock
 
 
 def test_pipeline_assessment_with_config():
```
