Skip to content

Commit d697cfc

Browse files
committed
Added support for Variant datatype in SQLAlchemy
1 parent b8e499b commit d697cfc

File tree

6 files changed

+141
-6
lines changed

6 files changed

+141
-6
lines changed

src/databricks/sqlalchemy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
TIMESTAMP_NTZ,
66
DatabricksArray,
77
DatabricksMap,
8+
DatabricksVariant,
89
)
910

10-
__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap"]
11+
__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap", "DatabricksVariant"]

src/databricks/sqlalchemy/_parse.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[st
318318
"map": sqlalchemy.types.String,
319319
"struct": sqlalchemy.types.String,
320320
"uniontype": sqlalchemy.types.String,
321+
"variant": type_overrides.DatabricksVariant,
321322
"decimal": sqlalchemy.types.Numeric,
322323
"timestamp": type_overrides.TIMESTAMP,
323324
"timestamp_ntz": type_overrides.TIMESTAMP_NTZ,

src/databricks/sqlalchemy/_types.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from databricks.sql.utils import ParamEscaper
1111

12+
from sqlalchemy.sql import expression
1213

1314
def process_literal_param_hack(value: Any):
1415
"""This method is supposed to accept a Python type and return a string representation of that type.
@@ -397,3 +398,47 @@ def compile_databricks_map(type_, compiler, **kw):
397398
key_type = compiler.process(type_.key_type, **kw)
398399
value_type = compiler.process(type_.value_type, **kw)
399400
return f"MAP<{key_type},{value_type}>"
401+
402+
class DatabricksVariant(UserDefinedType):
    """Databricks VARIANT column type for semi-structured data.

    Stores JSON-serialized content representing STRUCT, ARRAY, MAP, or
    scalar values. Note: VARIANT MAP types can only have STRING keys.

    Examples:
        DatabricksVariant() -> VARIANT

    Usage:
        Column('data', DatabricksVariant())
    """

    cache_ok = True

    def __init__(self):
        # Escaper used to safely embed literal JSON text in generated SQL.
        self.pe = ParamEscaper()

    def bind_processor(self, dialect):
        """Return a pass-through processor for bound parameter values.

        The JSON text is sent to the server unchanged; conversion to
        VARIANT happens in SQL via PARSE_JSON() (see bind_expression).
        """

        def process(value):
            return value

        return process

    def bind_expression(self, bindvalue):
        """Wrap the bound parameter with PARSE_JSON() in the emitted SQL."""
        return expression.func.PARSE_JSON(bindvalue)

    def literal_processor(self, dialect):
        """Return a processor that renders a literal VARIANT value.

        For VARIANT columns, PARSE_JSON() is applied to the escaped JSON
        string so that the data is ingested as VARIANT rather than STRING.

        Bug fix: the previous implementation returned the string
        f"PARSE_JSON('{process}')" -- embedding the repr of the inner
        function -- instead of returning the processor callable itself,
        which broke literal-bind compilation. The callable is now
        returned, and the PARSE_JSON() wrapping is applied per-value
        inside it.
        """

        def process(value):
            if value is None:
                return "NULL"
            # NOTE(review): assumes ParamEscaper.escape_string returns a
            # fully quoted, escaped SQL string literal -- confirm against
            # databricks.sql.utils.ParamEscaper.
            return f"PARSE_JSON({self.pe.escape_string(value)})"

        return process


@compiles(DatabricksVariant, "databricks")
def compile_variant(type_, compiler, **kw):
    """Compile DatabricksVariant to the Databricks SQL type name."""
    return "VARIANT"

tests/test_local/e2e/test_complex_types.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
DateTime,
1212
)
1313
from collections.abc import Sequence
14-
from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap
14+
from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap, DatabricksVariant
1515
from sqlalchemy.orm import DeclarativeBase, Session
1616
from sqlalchemy import select
1717
from datetime import date, datetime, time, timedelta, timezone
1818
import pandas as pd
1919
import numpy as np
2020
import decimal
21+
import json
2122

2223

2324
class TestComplexTypes(TestSetup):
@@ -46,7 +47,7 @@ def _parse_to_common_type(self, value):
4647
):
4748
return tuple(value)
4849
elif isinstance(value, dict):
49-
return tuple(value.items())
50+
return tuple(sorted(value.items()))
5051
elif isinstance(value, np.generic):
5152
return value.item()
5253
elif isinstance(value, decimal.Decimal):
@@ -152,6 +153,35 @@ class MapTable(Base):
152153

153154
return MapTable, sample_data
154155

156+
def sample_variant_table(self) -> tuple[DeclarativeBase, dict]:
    """Return a declarative model with four VARIANT columns plus sample data.

    The sample payloads cover a flat object, a nested object, a
    heterogeneous array, and a mix of scalars, arrays, and objects.
    """

    class Base(DeclarativeBase):
        pass

    class VariantTable(Base):
        __tablename__ = "sqlalchemy_variant_table"

        int_col = Column(Integer, primary_key=True)
        variant_simple_col = Column(DatabricksVariant())
        variant_nested_col = Column(DatabricksVariant())
        variant_array_col = Column(DatabricksVariant())
        variant_mixed_col = Column(DatabricksVariant())

    mixed_payload = {
        "string": "test",
        "number": 123,
        "boolean": True,
        "array": [1, 2, 3],
        "object": {"nested": "value"},
    }

    sample_data = {
        "int_col": 1,
        "variant_simple_col": {"key": "value", "number": 42},
        "variant_nested_col": {"user": {"name": "John", "age": 30}, "active": True},
        "variant_array_col": [1, 2, 3, "hello", {"nested": "data"}],
        "variant_mixed_col": mixed_payload,
    }

    return VariantTable, sample_data
184+
155185
def test_insert_array_table_sqlalchemy(self):
156186
table, sample_data = self.sample_array_table()
157187

@@ -209,3 +239,57 @@ def test_map_table_creation_pandas(self):
209239
stmt = select(table)
210240
df_result = pd.read_sql(stmt, engine)
211241
assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
242+
243+
def test_insert_variant_table_sqlalchemy(self):
    """Round-trip VARIANT data through an ORM insert and select."""
    table, sample_data = self.sample_variant_table()
    variant_cols = (
        "variant_simple_col",
        "variant_nested_col",
        "variant_array_col",
        "variant_mixed_col",
    )

    with self.table_context(table) as engine:
        # VARIANT values travel as JSON text, so serialize them up front.
        row = dict(sample_data)
        for col in variant_cols:
            if row[col] is not None:
                row[col] = json.dumps(row[col])

        session = Session(engine)
        session.add(table(**row))
        session.commit()

        fetched = session.scalar(select(table).where(table.int_col == 1))

        compare = {name: getattr(fetched, name) for name in sample_data}
        # Decode JSON text back into Python objects before comparing.
        for col in variant_cols:
            if compare[col] is not None:
                compare[col] = json.loads(compare[col])
        assert self._recursive_compare(compare, sample_data)
267+
268+
def test_variant_table_creation_pandas(self):
    """Round-trip VARIANT data through pandas to_sql / read_sql."""
    table, sample_data = self.sample_variant_table()
    variant_cols = (
        "variant_simple_col",
        "variant_nested_col",
        "variant_array_col",
        "variant_mixed_col",
    )

    with self.table_context(table) as engine:
        # Serialize VARIANT payloads to JSON text before handing to pandas.
        row = dict(sample_data)
        for col in variant_cols:
            if row[col] is not None:
                row[col] = json.dumps(row[col])

        frame = pd.DataFrame([row])
        frame.to_sql(
            table.__tablename__,
            engine,
            if_exists="append",
            index=False,
            dtype={col: DatabricksVariant for col in variant_cols},
        )

        # Read the row back and decode JSON text for comparison.
        result = pd.read_sql(select(table), engine).iloc[0].to_dict()
        for col in variant_cols:
            if result[col] is not None:
                result[col] = json.loads(result[col])
        assert self._recursive_compare(result, sample_data)

tests/test_local/test_ddl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
SetColumnComment,
88
SetTableComment,
99
)
10-
from databricks.sqlalchemy import DatabricksArray, DatabricksMap
10+
from databricks.sqlalchemy import DatabricksArray, DatabricksMap, DatabricksVariant
1111

1212

1313
class DDLTestBase:
@@ -103,7 +103,8 @@ def metadata(self) -> MetaData:
103103
metadata = MetaData()
104104
col1 = Column("array_array_string", DatabricksArray(DatabricksArray(String)))
105105
col2 = Column("map_string_string", DatabricksMap(String, String))
106-
table = Table("complex_type", metadata, col1, col2)
106+
col3 = Column("variant_col", DatabricksVariant())
107+
table = Table("complex_type", metadata, col1, col2, col3)
107108
return metadata
108109

109110
def test_create_table_with_complex_type(self, metadata):
@@ -112,3 +113,4 @@ def test_create_table_with_complex_type(self, metadata):
112113

113114
assert "array_array_string ARRAY<ARRAY<STRING>>" in output
114115
assert "map_string_string MAP<STRING,STRING>" in output
116+
assert "variant_col VARIANT" in output

tests/test_local/test_types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sqlalchemy
55

66
from databricks.sqlalchemy.base import DatabricksDialect
7-
from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ
7+
from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ, DatabricksVariant
88

99

1010
class DatabricksDataType(enum.Enum):
@@ -28,6 +28,7 @@ class DatabricksDataType(enum.Enum):
2828
ARRAY = enum.auto()
2929
MAP = enum.auto()
3030
STRUCT = enum.auto()
31+
VARIANT = enum.auto()
3132

3233

3334
# Defines the way that SQLAlchemy CamelCase types are compiled into Databricks SQL types.
@@ -131,6 +132,7 @@ def test_numeric_renders_as_decimal_with_precision_and_scale(self):
131132
TINYINT: DatabricksDataType.TINYINT,
132133
TIMESTAMP: DatabricksDataType.TIMESTAMP,
133134
TIMESTAMP_NTZ: DatabricksDataType.TIMESTAMP_NTZ,
135+
DatabricksVariant: DatabricksDataType.VARIANT,
134136
}
135137

136138

0 commit comments

Comments
 (0)