Commit 966a7aa

simple col transformations
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 55d8c75 commit 966a7aa

4 files changed, +105 -12 lines changed

src/databricks/sql/backend/sea/result_set.py

Lines changed: 17 additions & 7 deletions
@@ -320,7 +320,7 @@ def _prepare_column_mapping(self) -> None:
             None,
             None,
             None,
-            True,
+            None,
         )

         # Set the mapping
@@ -357,13 +357,20 @@ def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Table":
                 else None
             )

-            column = (
-                pyarrow.nulls(table.num_rows)
-                if old_idx is None
-                else table.column(old_idx)
-            )
-            new_columns.append(column)
+            if old_idx is None:
+                column = pyarrow.nulls(table.num_rows)
+            else:
+                column = table.column(old_idx)
+                # Apply transform if available
+                if result_column.transform_value:
+                    # Convert to list, apply transform, and convert back
+                    values = column.to_pylist()
+                    transformed_values = [
+                        result_column.transform_value(v) for v in values
+                    ]
+                    column = pyarrow.array(transformed_values)

+            new_columns.append(column)
             column_names.append(result_column.thrift_col_name)

         return pyarrow.Table.from_arrays(new_columns, names=column_names)
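
For illustration, a minimal standalone sketch of the list round-trip this hunk introduces. The sample column is hypothetical; transform_nullable is copied verbatim from the new metadata_transforms module added in this commit:

import pyarrow

def transform_nullable(value):
    # copied from metadata_transforms: "true"/True -> 1, "false"/False -> 0
    if value is True or value == "true" or value == "YES":
        return 1
    elif value is False or value == "false" or value == "NO":
        return 0
    return value

# a stand-in for one metadata column as SEA might return it
column = pyarrow.array(["true", "false", None])

# same round-trip as the hunk above: Arrow -> Python list -> transform -> Arrow
values = column.to_pylist()
transformed_values = [transform_nullable(v) for v in values]
column = pyarrow.array(transformed_values)

print(column.to_pylist())  # [1, 0, None]

The to_pylist round-trip trades speed for simplicity; since metadata result sets are small, materialising Python objects here is a reasonable cost.
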
@@ -384,6 +391,9 @@ def _normalise_json_metadata_cols(self, rows: List[List[str]]) -> List[List[Any]]:
                 )

                 value = None if old_idx is None else row[old_idx]
+                # Apply transform if available
+                if value is not None and result_column.transform_value:
+                    value = result_column.transform_value(value)
                 new_row.append(value)
             transformed_rows.append(new_row)
         return transformed_rows
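
On the JSON path the same callbacks run value by value, with None passed through untouched. A small sketch with hypothetical row values:

from databricks.sql.backend.sea.utils.metadata_transforms import transform_is_nullable

row_values = ["true", None, "false"]  # hypothetical raw JSON metadata values
normalised = [
    transform_is_nullable(v) if v is not None else None
    for v in row_values
]
print(normalised)  # ['YES', None, 'NO']
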

src/databricks/sql/backend/sea/utils/metadata_mappings.py

Lines changed: 18 additions & 4 deletions
@@ -1,5 +1,13 @@
 from databricks.sql.backend.sea.utils.result_column import ResultColumn
 from databricks.sql.backend.sea.utils.conversion import SqlType
+from databricks.sql.backend.sea.utils.metadata_transforms import (
+    transform_remarks,
+    transform_is_autoincrement,
+    transform_is_nullable,
+    transform_nullable,
+    transform_data_type,
+    transform_ordinal_position,
+)


 class MetadataColumnMappings:
@@ -28,7 +36,7 @@ class MetadataColumnMappings:
     REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, SqlType.STRING)

     COL_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.STRING)
-    DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, SqlType.INT)
+    DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, SqlType.INT, transform_data_type)
     COLUMN_TYPE_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.STRING)
     COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", SqlType.INT)
     BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.TINYINT)
@@ -43,22 +51,28 @@ class MetadataColumnMappings:
         "ORDINAL_POSITION",
         "ordinalPosition",
         SqlType.INT,
+        transform_ordinal_position,
     )

-    NULLABLE_COLUMN = ResultColumn("NULLABLE", None, SqlType.INT)
+    NULLABLE_COLUMN = ResultColumn("NULLABLE", None, SqlType.INT, transform_nullable)
     COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", None, SqlType.STRING)
     SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, SqlType.INT)
     SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, SqlType.INT)
     CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, SqlType.INT)
-    IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.STRING)
+    IS_NULLABLE_COLUMN = ResultColumn(
+        "IS_NULLABLE", "isNullable", SqlType.STRING, transform_is_nullable
+    )

     SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.STRING)
     SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.STRING)
     SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, SqlType.STRING)
     SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, SqlType.SMALLINT)

     IS_AUTO_INCREMENT_COLUMN = ResultColumn(
-        "IS_AUTOINCREMENT", "isAutoIncrement", SqlType.STRING
+        "IS_AUTO_INCREMENT",
+        "isAutoIncrement",
+        SqlType.STRING,
+        transform_is_autoincrement,
     )
     IS_GENERATED_COLUMN = ResultColumn(
         "IS_GENERATEDCOLUMN", "isGenerated", SqlType.STRING
src/databricks/sql/backend/sea/utils/metadata_transforms.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+"""Simple transformation functions for metadata value normalization."""
+
+
+def transform_is_autoincrement(value):
+    """Transform IS_AUTOINCREMENT: boolean to YES/NO string."""
+    if isinstance(value, bool):
+        return "YES" if value else "NO"
+    return value
+
+
+def transform_is_nullable(value):
+    """Transform IS_NULLABLE: true/false to YES/NO string."""
+    if value is True or value == "true":
+        return "YES"
+    elif value is False or value == "false":
+        return "NO"
+    return value
+
+
+def transform_nullable(value):
+    """Transform NULLABLE column: boolean/string to integer."""
+    if value is True or value == "true" or value == "YES":
+        return 1
+    elif value is False or value == "false" or value == "NO":
+        return 0
+    return value
+
+
+# Type code mapping based on JDBC specification
+TYPE_CODE_MAP = {
+    "STRING": 12,  # VARCHAR
+    "VARCHAR": 12,  # VARCHAR
+    "CHAR": 1,  # CHAR
+    "INT": 4,  # INTEGER
+    "INTEGER": 4,  # INTEGER
+    "BIGINT": -5,  # BIGINT
+    "SMALLINT": 5,  # SMALLINT
+    "TINYINT": -6,  # TINYINT
+    "DOUBLE": 8,  # DOUBLE
+    "FLOAT": 6,  # FLOAT
+    "REAL": 7,  # REAL
+    "DECIMAL": 3,  # DECIMAL
+    "NUMERIC": 2,  # NUMERIC
+    "BOOLEAN": 16,  # BOOLEAN
+    "DATE": 91,  # DATE
+    "TIMESTAMP": 93,  # TIMESTAMP
+    "BINARY": -2,  # BINARY
+    "ARRAY": 2003,  # ARRAY
+    "MAP": 2002,  # JAVA_OBJECT
+    "STRUCT": 2002,  # JAVA_OBJECT
+}
+
+
+def transform_data_type(value):
+    """Transform DATA_TYPE: type name to JDBC type code."""
+    if isinstance(value, str):
+        # Handle parameterized types like DECIMAL(10,2)
+        base_type = value.split("(")[0].upper()
+        return TYPE_CODE_MAP.get(base_type, value)
+    return value
+
+
+def transform_ordinal_position(value):
+    """Transform ORDINAL_POSITION: decrement by 1 (1-based to 0-based)."""
+    if isinstance(value, int):
+        return value - 1
+    return value
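
Assuming the module is importable at the path used in the imports above, a quick sanity check of the new helpers with illustrative inputs:

from databricks.sql.backend.sea.utils.metadata_transforms import (
    transform_data_type,
    transform_is_autoincrement,
    transform_ordinal_position,
)

print(transform_data_type("DECIMAL(10,2)"))  # 3, the JDBC DECIMAL code
print(transform_data_type("unknown_type"))   # returned unchanged: no map entry
print(transform_is_autoincrement(True))      # 'YES'
print(transform_ordinal_position(1))         # 0

Each helper falls through and returns its input unchanged when the value is not in the shape it expects, so the transforms are safe to apply to already-normalized data.
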

src/databricks/sql/backend/sea/utils/result_column.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Callable, Any


 @dataclass(frozen=True)
@@ -11,8 +11,10 @@ class ResultColumn:
         thrift_col_name: Column name as returned by Thrift (e.g., "TABLE_CAT")
         sea_col_name: Server result column name from SEA (e.g., "catalog")
         thrift_col_type: SQL type name
+        transform_value: Optional callback to transform values for this column
     """

     thrift_col_name: str
     sea_col_name: Optional[str]  # None if SEA doesn't return this column
     thrift_col_type: str
+    transform_value: Optional[Callable[[Any], Any]] = None
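
A minimal sketch of the new field in use, with a hypothetical callback (the real mappings pass the helpers from metadata_transforms, and the type string would normally come from SqlType):

from databricks.sql.backend.sea.utils.result_column import ResultColumn

# hypothetical transform: shift SEA's 1-based positions to 0-based
def shift_to_zero_based(value):
    return value - 1 if isinstance(value, int) else value

col = ResultColumn("ORDINAL_POSITION", "ordinalPosition", "INT", shift_to_zero_based)
print(col.transform_value(5))  # 4

Because transform_value defaults to None, all existing ResultColumn definitions that omit it keep working unchanged.
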
