Commit df7f1d7
# Escape column names in target tables of the table migration (#2563)
## Changes

Escape column names in target tables of the table migration.

### Linked issues

Resolves #2544

### Functionality

- [x] modified existing workflows: the `-migration-` ones

### Tests

- [x] added unit tests
- [x] changed integration tests
1 parent 591761c commit df7f1d7
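In effect, every column name in the DDL emitted for target tables is now backtick-escaped. A before/after sketch, abridged from the updated unit-test expectations further down (the trailing LOCATION clause is elided here):

```python
# Before: column names emitted verbatim, which breaks on names that need quoting.
before = "CREATE TABLE IF NOT EXISTS `tgt_catalog`.`tgt_db`.`test` (col1 string, col2 decimal) ..."
# After: each column name is wrapped in backticks.
after = "CREATE TABLE IF NOT EXISTS `tgt_catalog`.`tgt_db`.`test` (`col1` string, `col2` decimal) ..."
```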

File tree

9 files changed: +146 additions, −22 deletions

src/databricks/labs/ucx/framework/utils.py

Lines changed: 3 additions & 2 deletions
@@ -4,19 +4,20 @@
 logger = logging.getLogger(__name__)


-def escape_sql_identifier(path: str) -> str:
+def escape_sql_identifier(path: str, *, maxsplit: int = 2) -> str:
     """
     Escapes the path components to make them SQL safe.

     Args:
         path (str): The dot-separated path of a catalog object.
+        maxsplit (int): The maximum number of splits to perform.

     Returns:
         str: The path with all parts escaped in backticks.
     """
     if not path:
         return path
-    parts = path.split(".", maxsplit=2)
+    parts = path.split(".", maxsplit=maxsplit)
     escaped = [f"`{part.strip('`').replace('`', '``')}`" for part in parts]
     return ".".join(escaped)
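A minimal sketch of the change in behavior, derived from the function body above: the default `maxsplit=2` treats a dotted path as `catalog.schema.object`, while the new `maxsplit=0` escapes the whole string as a single identifier, which is what the column-name call sites use.

```python
from databricks.labs.ucx.framework.utils import escape_sql_identifier

# Default: split into at most three parts (catalog, schema, object).
assert escape_sql_identifier("catalog.schema.table") == "`catalog`.`schema`.`table`"

# maxsplit=0: no splitting, so a column name containing periods stays whole.
assert escape_sql_identifier("column.with.periods", maxsplit=0) == "`column.with.periods`"

# Embedded backticks are doubled after stripping any outer ones.
assert escape_sql_identifier("1-0`.0-ugly-column", maxsplit=0) == "`1-0``.0-ugly-column`"
```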

src/databricks/labs/ucx/hive_metastore/tables.py

Lines changed: 5 additions & 5 deletions
@@ -305,17 +305,17 @@ def sql_migrate_view(self, target_table_key):
     def sql_migrate_table_in_mount(self, target_table_key: str, table_schema: Iterator[typing.Any]):
         fields = []
         partitioned_fields = []
-        next_fileds_are_partitioned = False
+        next_fields_are_partitioned = False
         for key, value, _ in table_schema:
             if key == "# Partition Information":
                 continue
             if key == "# col_name":
-                next_fileds_are_partitioned = True
+                next_fields_are_partitioned = True
                 continue
-            if next_fileds_are_partitioned:
-                partitioned_fields.append(f"{key}")
+            if next_fields_are_partitioned:
+                partitioned_fields.append(escape_sql_identifier(key, maxsplit=0))
             else:
-                fields.append(f"{key} {value}")
+                fields.append(f"{escape_sql_identifier(key, maxsplit=0)} {value}")

         partitioned_str = ""
         if partitioned_fields:
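For context, `sql_migrate_table_in_mount` iterates over DESCRIBE-style `(col_name, data_type, comment)` rows, where rows after the `# col_name` marker name partition columns. A small sketch of the mapping, mirroring the new unit test further down:

```python
# DESCRIBE-style rows as consumed by sql_migrate_table_in_mount:
table_schema = [
    ("id", "STRING", ""),
    ("country", "STRING", ""),
    ("# Partition Information", "", ""),  # skipped
    ("# col_name", "", ""),               # everything after is a partition column
    ("country", "", ""),
]
# Both field and partition names now pass through
# escape_sql_identifier(..., maxsplit=0), yielding:
#   (`id` STRING, `country` STRING) PARTITIONED BY (`country`)
```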

src/databricks/labs/ucx/mixins/fixtures.py

Lines changed: 57 additions & 4 deletions
@@ -25,6 +25,7 @@
     AwsIamRoleRequest,
     AzureServicePrincipal,
     CatalogInfo,
+    ColumnInfo,
     DataSourceFormat,
     FunctionInfo,
     SchemaInfo,
@@ -55,6 +56,7 @@
 from databricks.sdk.service.workspace import ImportFormat, Language

 from databricks.labs.ucx.workspace_access.groups import MigratedGroup
+from databricks.labs.ucx.framework.utils import escape_sql_identifier

 # this file will get to databricks-labs-pytester project and be maintained/refactored there
 # pylint: disable=redefined-outer-name,too-many-try-statements,import-outside-toplevel,unnecessary-lambda,too-complex,invalid-name
@@ -1014,6 +1016,37 @@ def remove(schema_info: SchemaInfo):
 @pytest.fixture
 # pylint: disable-next=too-many-statements
 def make_table(ws, sql_backend, make_schema, make_random) -> Generator[Callable[..., TableInfo], None, None]:
+    def generate_sql_schema(columns: list[ColumnInfo]) -> str:
+        """Generate a SQL schema from columns."""
+        schema = "("
+        for index, column in enumerate(columns):
+            schema += escape_sql_identifier(column.name or str(index), maxsplit=0)
+            if column.type_name is None:
+                type_name = "STRING"
+            else:
+                type_name = column.type_name.value
+            schema += f" {type_name}, "
+        schema = schema[:-2] + ")"  # Remove the last ', '
+        return schema
+
+    def generate_sql_column_casting(existing_columns: list[ColumnInfo], new_columns: list[ColumnInfo]) -> str:
+        """Generate the SQL to cast columns"""
+        if any(column.name is None for column in existing_columns):
+            raise ValueError(f"Columns should have a name: {existing_columns}")
+        if len(new_columns) > len(existing_columns):
+            raise ValueError(f"Too many columns: {new_columns}")
+        select_expressions = []
+        for index, (existing_column, new_column) in enumerate(zip(existing_columns, new_columns)):
+            column_name_new = escape_sql_identifier(new_column.name or str(index), maxsplit=0)
+            if new_column.type_name is None:
+                type_name = "STRING"
+            else:
+                type_name = new_column.type_name.value
+            select_expression = f"CAST({existing_column.name} AS {type_name}) AS {column_name_new}"
+            select_expressions.append(select_expression)
+        select = ", ".join(select_expressions)
+        return select
+
     def create(  # pylint: disable=too-many-locals,too-many-arguments,too-many-statements
         *,
         catalog_name="hive_metastore",
@@ -1028,6 +1061,7 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
         tbl_properties: dict[str, str] | None = None,
         hiveserde_ddl: str | None = None,
         storage_override: str | None = None,
+        columns: list[ColumnInfo] | None = None,
     ) -> TableInfo:
         if schema_name is None:
             schema = make_schema(catalog_name=catalog_name)
@@ -1041,6 +1075,10 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
         view_text = None
         full_name = f"{catalog_name}.{schema_name}.{name}".lower()
         ddl = f'CREATE {"VIEW" if view else "TABLE"} {full_name}'
+        if columns is None:
+            schema = "(id INT, value STRING)"
+        else:
+            schema = generate_sql_schema(columns)
         if view:
             table_type = TableType.VIEW
             view_text = ctas
@@ -1052,21 +1090,36 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
             data_source_format = DataSourceFormat.JSON
             # DBFS locations are not purged; no suffix necessary.
             storage_location = f"dbfs:/tmp/ucx_test_{make_random(4)}"
+            if columns is None:
+                select = "*"
+            else:
+                # These are the columns from the JSON dataset below
+                dataset_columns = [
+                    ColumnInfo(name="calories_burnt"),
+                    ColumnInfo(name="device_id"),
+                    ColumnInfo(name="id"),
+                    ColumnInfo(name="miles_walked"),
+                    ColumnInfo(name="num_steps"),
+                    ColumnInfo(name="timestamp"),
+                    ColumnInfo(name="user_id"),
+                    ColumnInfo(name="value"),
+                ]
+                select = generate_sql_column_casting(dataset_columns, columns)
             # Modified, otherwise it will identify the table as a DB Dataset
             ddl = (
-                f"{ddl} USING json location '{storage_location}' as SELECT * FROM "
+                f"{ddl} USING json location '{storage_location}' as SELECT {select} FROM "
                 f"JSON.`dbfs:/databricks-datasets/iot-stream/data-device`"
             )
         elif external_csv is not None:
             table_type = TableType.EXTERNAL
             data_source_format = DataSourceFormat.CSV
             storage_location = external_csv
-            ddl = f"{ddl} USING CSV OPTIONS (header=true) LOCATION '{storage_location}'"
+            ddl = f"{ddl} {schema} USING CSV OPTIONS (header=true) LOCATION '{storage_location}'"
         elif external_delta is not None:
             table_type = TableType.EXTERNAL
             data_source_format = DataSourceFormat.DELTA
             storage_location = external_delta
-            ddl = f"{ddl} (id string) LOCATION '{storage_location}'"
+            ddl = f"{ddl} {schema} LOCATION '{storage_location}'"
         elif external:
             # external table
             table_type = TableType.EXTERNAL
@@ -1079,7 +1132,7 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
             table_type = TableType.MANAGED
             data_source_format = DataSourceFormat.DELTA
             storage_location = f"dbfs:/user/hive/warehouse/{schema_name}/{name}"
-            ddl = f"{ddl} (id INT, value STRING)"
+            ddl = f"{ddl} {schema}"
         if tbl_properties:
             tbl_properties.update({"RemoveAfter": get_test_purge_time()})
         else:
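A hypothetical test snippet showing how the new `columns` parameter is exercised (mirroring the integration-test changes below); `type_name` falls back to `STRING` when omitted:

```python
from databricks.sdk.service.catalog import ColumnInfo, ColumnTypeName

# Create a managed table whose single column name needs escaping; the fixture
# renders the schema as (`1-0``.0-ugly-column` STRING).
table = make_table(
    columns=[ColumnInfo(name="1-0`.0-ugly-column", type_name=ColumnTypeName.STRING)],
)
```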

tests/integration/hive_metastore/test_migrate.py

Lines changed: 19 additions & 4 deletions
@@ -5,7 +5,7 @@
 from databricks.sdk.errors import NotFound
 from databricks.sdk.retries import retried
 from databricks.sdk.service.compute import DataSecurityMode, AwsAttributes
-from databricks.sdk.service.catalog import Privilege, SecurableType, TableInfo, TableType
+from databricks.sdk.service.catalog import ColumnInfo, ColumnTypeName, Privilege, SecurableType, TableInfo, TableType
 from databricks.sdk.service.iam import PermissionLevel
 from databricks.labs.ucx.config import WorkspaceConfig
 from databricks.labs.ucx.hive_metastore.mapping import Rule, TableMapping
@@ -20,7 +20,11 @@
 @retried(on=[NotFound], timeout=timedelta(minutes=2))
 def test_migrate_managed_tables(ws, sql_backend, runtime_ctx, make_catalog):
     src_schema = runtime_ctx.make_schema(catalog_name="hive_metastore")
-    src_managed_table = runtime_ctx.make_table(catalog_name=src_schema.catalog_name, schema_name=src_schema.name)
+    src_managed_table = runtime_ctx.make_table(
+        catalog_name=src_schema.catalog_name,
+        schema_name=src_schema.name,
+        columns=[ColumnInfo(name="-das-hes-", type_name=ColumnTypeName.STRING)],  # Test with column that needs escaping
+    )

     dst_catalog = make_catalog()
     dst_schema = runtime_ctx.make_schema(catalog_name=dst_catalog.name, name=src_schema.name)
@@ -48,7 +52,11 @@ def test_migrate_dbfs_non_delta_tables(ws, sql_backend, runtime_ctx, make_catalo
         pytest.skip("temporary: only works in azure test env")
     src_schema = runtime_ctx.make_schema(catalog_name="hive_metastore")
     src_managed_table = runtime_ctx.make_table(
-        catalog_name=src_schema.catalog_name, non_delta=True, schema_name=src_schema.name
+        catalog_name=src_schema.catalog_name,
+        non_delta=True,
+        schema_name=src_schema.name,
+        # Test with column that needs escaping
+        columns=[ColumnInfo(name="1-0`.0-ugly-column", type_name=ColumnTypeName.STRING)],
     )

     dst_catalog = make_catalog()
@@ -134,7 +142,12 @@ def test_migrate_external_table(
     make_mounted_location,
 ):
     src_schema = runtime_ctx.make_schema(catalog_name="hive_metastore")
-    src_external_table = runtime_ctx.make_table(schema_name=src_schema.name, external_csv=make_mounted_location)
+    src_external_table = runtime_ctx.make_table(
+        schema_name=src_schema.name,
+        external_csv=make_mounted_location,
+        # Test with column that needs escaping
+        columns=[ColumnInfo(name="`back`ticks`", type_name=ColumnTypeName.STRING)],
+    )
     dst_catalog = make_catalog()
     dst_schema = runtime_ctx.make_schema(catalog_name=dst_catalog.name, name=src_schema.name)
     logger.info(f"dst_catalog={dst_catalog.name}, external_table={src_external_table.full_name}")
@@ -667,6 +680,8 @@ def test_migrate_table_in_mount(
     src_external_table = runtime_ctx.make_table(
         schema_name=src_schema.name,
         external_delta=f"dbfs:/mnt/{env_or_skip('TEST_MOUNT_NAME')}/a/b/{table_path}",
+        # Test with column that needs escaping
+        columns=[ColumnInfo(name="1-0`.0-ugly-column", type_name=ColumnTypeName.STRING)],
     )
     table_in_mount_location = f"abfss://[email protected]/a/b/{table_path}"
     # TODO: Remove this hack below

tests/unit/conftest.py

Lines changed: 5 additions & 0 deletions
@@ -108,3 +108,8 @@ def mock_notebook_resolver():
     resolver = create_autospec(BaseNotebookResolver)
     resolver.resolve_notebook.return_value = None
     return resolver
+
+
+@pytest.fixture
+def mock_backend():
+    return MockBackend()

tests/unit/framework/test_utils.py

Lines changed: 6 additions & 0 deletions
@@ -28,3 +28,9 @@
 )
 def test_escaped_path(path: str, expected: str) -> None:
     assert escape_sql_identifier(path) == expected
+
+
+def test_escaped_when_column_contains_period() -> None:
+    expected = "`column.with.periods`"
+    path = "column.with.periods"
+    assert escape_sql_identifier(path, maxsplit=0) == expected

tests/unit/hive_metastore/test_table_migrate.py

Lines changed: 2 additions & 2 deletions
@@ -1068,7 +1068,7 @@ def test_table_in_mount_mapping_with_table_owner():
     )
     table_migrate.migrate_tables(what=What.TABLE_IN_MOUNT)
     assert (
-        "CREATE TABLE IF NOT EXISTS `tgt_catalog`.`tgt_db`.`test` (col1 string, col2 decimal) LOCATION 'abfss://bucket@msft/path/test';"
+        "CREATE TABLE IF NOT EXISTS `tgt_catalog`.`tgt_db`.`test` (`col1` string, `col2` decimal) LOCATION 'abfss://bucket@msft/path/test';"
         in backend.queries
     )
     migrate_grants.apply.assert_called()
@@ -1111,7 +1111,7 @@ def test_table_in_mount_mapping_with_partition_information():
     )
     table_migrate.migrate_tables(what=What.TABLE_IN_MOUNT)
     assert (
-        "CREATE TABLE IF NOT EXISTS `tgt_catalog`.`tgt_db`.`test` (col1 string, col2 decimal) PARTITIONED BY (col1) LOCATION 'abfss://bucket@msft/path/test';"
+        "CREATE TABLE IF NOT EXISTS `tgt_catalog`.`tgt_db`.`test` (`col1` string, `col2` decimal) PARTITIONED BY (`col1`) LOCATION 'abfss://bucket@msft/path/test';"
         in backend.queries
     )
     migrate_grants.apply.assert_called()

tests/unit/hive_metastore/test_tables.py

Lines changed: 49 additions & 0 deletions
@@ -127,6 +127,55 @@ def test_uc_sql(table, target, query):
     assert table.sql_migrate_external(target) == query


+@pytest.mark.parametrize(
+    "schema,partitions,table_schema",
+    [
+        (
+            "(`id` INT, `value` STRING)",
+            "",
+            [
+                ("id", "INT", ""),
+                ("value", "STRING", ""),
+            ],
+        ),
+        (
+            "(`column.with.periods` STRING)",
+            "",
+            [
+                ("column.with.periods", "STRING", ""),
+            ],
+        ),
+        (
+            "(`id` STRING, `country` STRING)",
+            "PARTITIONED BY (`country`)",
+            [
+                ("id", "STRING", ""),
+                ("country", "STRING", ""),
+                ("# Partition Information", "", ""),
+                ("# col_name", "", ""),
+                ("country", "", ""),
+            ],
+        ),
+    ],
+)
+def test_uc_sql_when_table_is_in_mount(schema, partitions, table_schema):
+    expected = (
+        f"CREATE TABLE IF NOT EXISTS `new_catalog`.`db`.`external_table` "
+        f"{schema} {partitions} LOCATION 's3a://foo/bar';"
+    )
+    table = Table(
+        catalog="catalog",
+        database="db",
+        name="external_table",
+        object_type="EXTERNAL",
+        table_format="DELTA",
+        location="s3a://foo/bar",
+    )
+    target = "new_catalog.db.external_table"
+
+    assert table.sql_migrate_table_in_mount(target, table_schema) == expected
+
+
 def test_tables_returning_error_when_describing():
     errors = {"DESCRIBE TABLE EXTENDED `hive_metastore`.`database`.`table1`": "error"}
     rows = {

tests/unit/workspace_access/test_manager.py

Lines changed: 0 additions & 5 deletions
@@ -14,11 +14,6 @@
 from databricks.labs.ucx.workspace_access.manager import PermissionManager, Permissions


-@pytest.fixture
-def mock_backend():
-    return MockBackend()
-
-
 def test_inventory_permission_manager_init(mock_backend):
     permission_manager = PermissionManager(mock_backend, "test_database", [])
