     AwsIamRoleRequest,
     AzureServicePrincipal,
     CatalogInfo,
+    ColumnInfo,
     DataSourceFormat,
     FunctionInfo,
     SchemaInfo,
…
 from databricks.sdk.service.workspace import ImportFormat, Language

 from databricks.labs.ucx.workspace_access.groups import MigratedGroup
+from databricks.labs.ucx.framework.utils import escape_sql_identifier

 # this file will get to databricks-labs-pytester project and be maintained/refactored there
 # pylint: disable=redefined-outer-name,too-many-try-statements,import-outside-toplevel,unnecessary-lambda,too-complex,invalid-name
@@ -1014,6 +1016,37 @@ def remove(schema_info: SchemaInfo):
 @pytest.fixture
 # pylint: disable-next=too-many-statements
 def make_table(ws, sql_backend, make_schema, make_random) -> Generator[Callable[..., TableInfo], None, None]:
+    def generate_sql_schema(columns: list[ColumnInfo]) -> str:
+        """Generate a SQL schema from columns."""
+        schema = "("
+        for index, column in enumerate(columns):
+            schema += escape_sql_identifier(column.name or str(index), maxsplit=0)
+            if column.type_name is None:
+                type_name = "STRING"
+            else:
+                type_name = column.type_name.value
+            schema += f" {type_name}, "
+        schema = schema[:-2] + ")"  # Remove the last ', '
+        return schema
+
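For reviewers, a minimal sketch of what this helper yields, assuming the SDK's `ColumnTypeName` enum and that `escape_sql_identifier` backtick-quotes identifiers (the helper is nested inside the fixture, so this is illustrative rather than importable):

```python
from databricks.sdk.service.catalog import ColumnInfo, ColumnTypeName

# One named, typed column and one with neither: the helper falls back to the
# positional index for the name and to STRING for the type.
columns = [ColumnInfo(name="id", type_name=ColumnTypeName.INT), ColumnInfo()]
generate_sql_schema(columns)  # -> "(`id` INT, `1` STRING)"
```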
+    def generate_sql_column_casting(existing_columns: list[ColumnInfo], new_columns: list[ColumnInfo]) -> str:
+        """Generate the SQL to cast columns."""
+        if any(column.name is None for column in existing_columns):
+            raise ValueError(f"Columns should have a name: {existing_columns}")
+        if len(new_columns) > len(existing_columns):
+            raise ValueError(f"Too many columns: {new_columns}")
+        select_expressions = []
+        for index, (existing_column, new_column) in enumerate(zip(existing_columns, new_columns)):
+            column_name_new = escape_sql_identifier(new_column.name or str(index), maxsplit=0)
+            if new_column.type_name is None:
+                type_name = "STRING"
+            else:
+                type_name = new_column.type_name.value
+            select_expression = f"CAST({existing_column.name} AS {type_name}) AS {column_name_new}"
+            select_expressions.append(select_expression)
+        select = ", ".join(select_expressions)
+        return select
+
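Same caveats for the casting helper. Note the pairing is positional (`zip`), so the first new column is cast from the first existing column, and surplus existing columns are simply dropped:

```python
existing = [ColumnInfo(name="device_id"), ColumnInfo(name="num_steps")]
new = [
    ColumnInfo(name="device", type_name=ColumnTypeName.STRING),
    ColumnInfo(name="steps", type_name=ColumnTypeName.LONG),
]
generate_sql_column_casting(existing, new)
# -> "CAST(device_id AS STRING) AS `device`, CAST(num_steps AS LONG) AS `steps`"
```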
     def create(  # pylint: disable=too-many-locals,too-many-arguments,too-many-statements
         *,
         catalog_name="hive_metastore",
@@ -1028,6 +1061,7 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
         tbl_properties: dict[str, str] | None = None,
         hiveserde_ddl: str | None = None,
         storage_override: str | None = None,
+        columns: list[ColumnInfo] | None = None,
     ) -> TableInfo:
         if schema_name is None:
             schema = make_schema(catalog_name=catalog_name)
@@ -1041,6 +1075,10 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
         view_text = None
         full_name = f"{catalog_name}.{schema_name}.{name}".lower()
         ddl = f'CREATE {"VIEW" if view else "TABLE"} {full_name} '
+        if columns is None:
+            schema = "(id INT, value STRING)"
+        else:
+            schema = generate_sql_schema(columns)
         if view:
             table_type = TableType.VIEW
             view_text = ctas
@@ -1052,21 +1090,36 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
             data_source_format = DataSourceFormat.JSON
             # DBFS locations are not purged; no suffix necessary.
             storage_location = f"dbfs:/tmp/ucx_test_{make_random(4)}"
+            if columns is None:
+                select = "*"
+            else:
+                # These are the columns from the JSON dataset below
+                dataset_columns = [
+                    ColumnInfo(name="calories_burnt"),
+                    ColumnInfo(name="device_id"),
+                    ColumnInfo(name="id"),
+                    ColumnInfo(name="miles_walked"),
+                    ColumnInfo(name="num_steps"),
+                    ColumnInfo(name="timestamp"),
+                    ColumnInfo(name="user_id"),
+                    ColumnInfo(name="value"),
+                ]
+                select = generate_sql_column_casting(dataset_columns, columns)
             # Modified, otherwise it will identify the table as a DB Dataset
             ddl = (
-                f"{ddl} USING json location '{storage_location}' as SELECT * FROM "
+                f"{ddl} USING json location '{storage_location}' as SELECT {select} FROM "
                 f"JSON.`dbfs:/databricks-datasets/iot-stream/data-device`"
             )
         elif external_csv is not None:
             table_type = TableType.EXTERNAL
             data_source_format = DataSourceFormat.CSV
             storage_location = external_csv
-            ddl = f"{ddl} USING CSV OPTIONS (header=true) LOCATION '{storage_location}'"
+            ddl = f"{ddl} {schema} USING CSV OPTIONS (header=true) LOCATION '{storage_location}'"
         elif external_delta is not None:
             table_type = TableType.EXTERNAL
             data_source_format = DataSourceFormat.DELTA
             storage_location = external_delta
-            ddl = f"{ddl} (id string) LOCATION '{storage_location}'"
+            ddl = f"{ddl} {schema} LOCATION '{storage_location}'"
         elif external:
             # external table
             table_type = TableType.EXTERNAL
@@ -1079,7 +1132,7 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
             table_type = TableType.MANAGED
             data_source_format = DataSourceFormat.DELTA
             storage_location = f"dbfs:/user/hive/warehouse/{schema_name}/{name}"
-            ddl = f"{ddl} (id INT, value STRING) "
+            ddl = f"{ddl} {schema} "
         if tbl_properties:
             tbl_properties.update({"RemoveAfter": get_test_purge_time()})
         else:
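Taken together, a hypothetical test using the new `columns` parameter (test, fixture, and column names here are illustrative, not part of the change):

```python
from databricks.sdk.service.catalog import ColumnInfo, ColumnTypeName

def test_make_table_with_custom_columns(make_table) -> None:
    # Default branch creates a managed Delta table; the DDL becomes roughly:
    #   CREATE TABLE hive_metastore.<schema>.<name> (`first` INT, `second` STRING)
    table = make_table(
        columns=[
            ColumnInfo(name="first", type_name=ColumnTypeName.INT),
            ColumnInfo(name="second", type_name=ColumnTypeName.STRING),
        ]
    )
    assert table.full_name.startswith("hive_metastore.")
```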