1111 DateTime ,
1212)
1313from collections .abc import Sequence
14- from databricks .sqlalchemy import TIMESTAMP , TINYINT , DatabricksArray , DatabricksMap
14+ from databricks .sqlalchemy import TIMESTAMP , TINYINT , DatabricksArray , DatabricksMap , DatabricksVariant
1515from sqlalchemy .orm import DeclarativeBase , Session
1616from sqlalchemy import select
1717from datetime import date , datetime , time , timedelta , timezone
1818import pandas as pd
1919import numpy as np
2020import decimal
21+ import json
2122
2223
2324class TestComplexTypes (TestSetup ):
@@ -46,7 +47,7 @@ def _parse_to_common_type(self, value):
4647 ):
4748 return tuple (value )
4849 elif isinstance (value , dict ):
49- return tuple (value .items ())
50+ return tuple (sorted ( value .items () ))
5051 elif isinstance (value , np .generic ):
5152 return value .item ()
5253 elif isinstance (value , decimal .Decimal ):
@@ -152,6 +153,35 @@ class MapTable(Base):
152153
153154 return MapTable , sample_data
154155
def sample_variant_table(self) -> tuple[DeclarativeBase, dict]:
    """Build the VARIANT fixture: an ORM model plus one row of sample data.

    Returns:
        A ``(model, row)`` pair. ``model`` is a declarative class mapped to
        ``sqlalchemy_variant_table`` with an integer primary key and four
        ``DatabricksVariant`` columns. ``row`` maps every column name to a
        plain Python value (dicts, lists and scalars) for that column.
    """

    class Base(DeclarativeBase):
        # Declarative base local to this fixture.
        pass

    class VariantTable(Base):
        __tablename__ = "sqlalchemy_variant_table"

        int_col = Column(Integer, primary_key=True)
        variant_simple_col = Column(DatabricksVariant())
        variant_nested_col = Column(DatabricksVariant())
        variant_array_col = Column(DatabricksVariant())
        variant_mixed_col = Column(DatabricksVariant())

    # One payload per variant column, each exercising a different shape:
    # flat object, nested object, heterogeneous array, and a mix of all.
    flat_object = {"key": "value", "number": 42}
    nested_object = {"user": {"name": "John", "age": 30}, "active": True}
    mixed_array = [1, 2, 3, "hello", {"nested": "data"}]
    mixed_object = {
        "string": "test",
        "number": 123,
        "boolean": True,
        "array": [1, 2, 3],
        "object": {"nested": "value"},
    }

    # Keys are listed in the same order as the columns on the model.
    row = dict(
        int_col=1,
        variant_simple_col=flat_object,
        variant_nested_col=nested_object,
        variant_array_col=mixed_array,
        variant_mixed_col=mixed_object,
    )

    return VariantTable, row
184+
155185 def test_insert_array_table_sqlalchemy (self ):
156186 table , sample_data = self .sample_array_table ()
157187
@@ -209,3 +239,57 @@ def test_map_table_creation_pandas(self):
209239 stmt = select (table )
210240 df_result = pd .read_sql (stmt , engine )
211241 assert self ._recursive_compare (df_result .iloc [0 ].to_dict (), sample_data )
242+
def test_insert_variant_table_sqlalchemy(self):
    """Round-trip one row through the VARIANT columns using the ORM session.

    The variant values from ``sample_variant_table`` are JSON-encoded with
    ``json.dumps`` before the insert. After the row is selected back, each
    non-null variant value is decoded with ``json.loads`` and the resulting
    dict is compared against the original sample data with
    ``_recursive_compare``.
    """
    table, sample_data = self.sample_variant_table()
    variant_keys = (
        "variant_simple_col",
        "variant_nested_col",
        "variant_array_col",
        "variant_mixed_col",
    )

    with self.table_context(table) as engine:
        # Pre-serialize variant data for SQLAlchemy; None stays None so a
        # null variant is not turned into the JSON string "null".
        variant_data = {
            key: (
                json.dumps(value)
                if key in variant_keys and value is not None
                else value
            )
            for key, value in sample_data.items()
        }

        # The context manager closes the session (and releases its
        # connection) even when an assertion or database error is raised.
        with Session(engine) as session:
            session.add(table(**variant_data))
            session.commit()

            stmt = select(table).where(table.int_col == 1)
            result = session.scalar(stmt)
            assert result is not None, "inserted row was not found"

            # Read the attributes while the session is still open, so the
            # instance is never accessed in a detached state.
            compare = {key: getattr(result, key) for key in sample_data}

        # Parse JSON values back to original format for comparison.
        # NOTE(review): assumes the dialect returns variant values as JSON
        # text -- confirm against the DatabricksVariant result processor.
        for key in variant_keys:
            if compare[key] is not None:
                compare[key] = json.loads(compare[key])

        assert self._recursive_compare(compare, sample_data)
267+
def test_variant_table_creation_pandas(self):
    """Round-trip one row through the VARIANT columns using pandas.

    The sample row is JSON-encoded per variant column, appended to the table
    with ``DataFrame.to_sql`` (each variant column mapped to
    ``DatabricksVariant`` via ``dtype``), read back with ``pd.read_sql`` and
    compared against the original sample data after decoding the JSON.
    """
    table, sample_data = self.sample_variant_table()
    variant_columns = (
        "variant_simple_col",
        "variant_nested_col",
        "variant_array_col",
        "variant_mixed_col",
    )

    with self.table_context(table) as engine:
        # Pre-serialize variant data for pandas; non-variant columns and
        # None values are passed through unchanged.
        serialized_row = {
            name: (
                json.dumps(value)
                if name in variant_columns and value is not None
                else value
            )
            for name, value in sample_data.items()
        }

        # Insert the data into the table.
        frame = pd.DataFrame([serialized_row])
        column_types = dict.fromkeys(variant_columns, DatabricksVariant)
        frame.to_sql(
            table.__tablename__,
            engine,
            if_exists="append",
            index=False,
            dtype=column_types,
        )

        # Read the data from the table.
        fetched = pd.read_sql(select(table), engine)
        first_row = fetched.iloc[0].to_dict()

        # Parse JSON values back to original format for comparison.
        for name in variant_columns:
            raw = first_row[name]
            if raw is not None:
                first_row[name] = json.loads(raw)

        assert self._recursive_compare(first_row, sample_data)
0 commit comments