Commit b834ce7

add variant support

1 parent 7b51c6e · commit b834ce7

File tree

3 files changed: +275 -9 lines changed

  src/databricks/sql/thrift_backend.py
  tests/e2e/test_variant_types.py
  tests/unit/test_thrift_backend.py

src/databricks/sql/thrift_backend.py

Lines changed: 35 additions & 9 deletions
@@ -665,7 +665,7 @@ def convert_col(t_column_desc):
         return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])

     @staticmethod
-    def _col_to_description(col):
+    def _col_to_description(col, field):
         type_entry = col.typeDesc.types[0]

         if type_entry.primitiveEntry:
@@ -692,12 +692,36 @@ def _col_to_description(col):
         else:
             precision, scale = None, None

+        # Extract variant type from field if available
+        if field is not None:
+            try:
+                # Check for variant type in metadata
+                if field.metadata and b"Spark:DataType:SqlName" in field.metadata:
+                    sql_type = field.metadata.get(b"Spark:DataType:SqlName")
+                    if sql_type == b"VARIANT":
+                        cleaned_type = "variant"
+            except Exception as e:
+                logger.debug(f"Could not extract variant type from field: {e}")
+
         return col.columnName, cleaned_type, None, None, precision, scale, None

     @staticmethod
-    def _hive_schema_to_description(t_table_schema):
+    def _hive_schema_to_description(t_table_schema, schema_bytes=None):
+        # Create a field lookup dictionary for efficient column access
+        field_dict = {}
+        if pyarrow and schema_bytes:
+            try:
+                arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
+                # Build a dictionary mapping column names to fields
+                for field in arrow_schema:
+                    field_dict[field.name] = field
+            except Exception as e:
+                logger.debug(f"Could not parse arrow schema: {e}")
+
+        # Process each column with its corresponding Arrow field (if available)
         return [
-            ThriftBackend._col_to_description(col) for col in t_table_schema.columns
+            ThriftBackend._col_to_description(col, field_dict.get(col.columnName))
+            for col in t_table_schema.columns
         ]

     def _results_message_to_execute_response(self, resp, operation_state):
@@ -726,9 +750,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
             or (not direct_results.resultSet)
             or direct_results.resultSet.hasMoreRows
         )
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema
-        )

         if pyarrow:
             schema_bytes = (
@@ -740,6 +761,10 @@ def _results_message_to_execute_response(self, resp, operation_state):
         else:
             schema_bytes = None

+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema, schema_bytes
+        )
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         if direct_results and direct_results.resultSet:
@@ -793,9 +818,6 @@ def get_execution_result(self, op_handle, cursor):
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         has_more_rows = resp.hasMoreRows
-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema
-        )

         if pyarrow:
             schema_bytes = (
@@ -807,6 +829,10 @@ def get_execution_result(self, op_handle, cursor):
         else:
             schema_bytes = None

+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema, schema_bytes
+        )
+
         queue = ResultSetQueueFactory.build_queue(
             row_set_type=resp.resultSetMetadata.resultFormat,
             t_row_set=resp.results,
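
The mechanism above hinges on field-level metadata that survives Arrow IPC serialization. A minimal standalone sketch of that round trip, assuming only that pyarrow is installed (the names here are illustrative, not part of the commit):

    import pyarrow

    # Build a schema whose field carries the Spark SQL type name in its
    # metadata, mirroring what the server attaches for VARIANT columns.
    field = pyarrow.field(
        "variant_col",
        pyarrow.string(),
        metadata={b"Spark:DataType:SqlName": b"VARIANT"},
    )
    schema_bytes = pyarrow.schema([field]).serialize().to_pybytes()

    # Recover the schema from the raw IPC bytes, the same way
    # _hive_schema_to_description does above.
    arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
    recovered = arrow_schema.field("variant_col")
    assert recovered.metadata[b"Spark:DataType:SqlName"] == b"VARIANT"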

tests/e2e/test_variant_types.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+import pytest
+from datetime import datetime
+import json
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
+from tests.e2e.test_driver import PySQLPytestTestCase
+
+class TestVariantTypes(PySQLPytestTestCase):
+    """Tests for the proper detection and handling of VARIANT type columns"""
+
+    @pytest.fixture(scope="class")
+    def variant_table_fixture(self, connection_details):
+        self.arguments = connection_details.copy()
+        """A pytest fixture that creates a table with variant columns, inserts records, yields, and then drops the table"""
+
+        with self.cursor() as cursor:
+            # Check if VARIANT type is supported
+            try:
+                # delete the table if it exists
+                cursor.execute("DROP TABLE IF EXISTS pysql_test_variant_types_table")
+
+                # Create the table with variant columns
+                cursor.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS pysql_test_variant_types_table (
+                        id INTEGER,
+                        variant_col VARIANT,
+                        regular_string_col STRING
+                    )
+                    """
+                )
+
+                # Insert test records with different variant values
+                cursor.execute(
+                    """
+                    INSERT INTO pysql_test_variant_types_table
+                    VALUES
+                    (1, PARSE_JSON('{"name": "John", "age": 30}'), 'regular string'),
+                    (2, PARSE_JSON('[1, 2, 3, 4]'), 'another string')
+                    """
+                )
+
+                variant_supported = True
+            except Exception as e:
+                # VARIANT type not supported in this environment
+                print(f"VARIANT type not supported: {e}")
+                variant_supported = False
+
+            yield variant_supported
+
+            # Clean up if table was created
+            if variant_supported:
+                cursor.execute("DROP TABLE IF EXISTS pysql_test_variant_types_table")
+
+    def test_variant_type_detection(self, variant_table_fixture):
+        """Test that VARIANT type columns are properly detected"""
+        if not variant_table_fixture:
+            pytest.skip("VARIANT type not supported in this environment")
+
+        with self.cursor() as cursor:
+            cursor.execute("SELECT * FROM pysql_test_variant_types_table LIMIT 1")
+
+            # Check that the column type is properly detected as 'variant'
+            assert cursor.description[1][1] == 'variant', "VARIANT column type not correctly identified"
+
+            # Regular string column should still be reported as string
+            assert cursor.description[2][1] == 'string', "Regular string column type not correctly identified"
+
+    def test_variant_data_retrieval(self, variant_table_fixture):
+        """Test that VARIANT data is properly retrieved and can be accessed as JSON"""
+        if not variant_table_fixture:
+            pytest.skip("VARIANT type not supported in this environment")
+
+        with self.cursor() as cursor:
+            cursor.execute("SELECT * FROM pysql_test_variant_types_table ORDER BY id")
+            rows = cursor.fetchall()
+
+            # First row should have a JSON object
+            json_obj = rows[0][1]
+            assert isinstance(json_obj, str), "VARIANT column should be returned as string"
+
+            # Parse to verify it's valid JSON
+            parsed = json.loads(json_obj)
+            assert parsed.get('name') == 'John'
+            assert parsed.get('age') == 30
+
+            # Second row should have a JSON array
+            json_array = rows[1][1]
+            assert isinstance(json_array, str), "VARIANT array should be returned as string"
+
+            # Parse to verify it's a valid JSON array
+            parsed_array = json.loads(json_array)
+            assert isinstance(parsed_array, list)
+            assert parsed_array == [1, 2, 3, 4]
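
End to end, the change is visible through the DB-API surface these tests exercise. A hedged usage sketch, assuming a reachable warehouse and the table created by the fixture above (the connection parameters are placeholders, not real values):

    import json
    from databricks import sql

    with sql.connect(
        server_hostname="<hostname>",
        http_path="<http-path>",
        access_token="<token>",
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute(
                "SELECT variant_col FROM pysql_test_variant_types_table LIMIT 1"
            )
            # With this commit, cursor.description reports 'variant' rather
            # than 'string' as the type_code for VARIANT columns.
            name, type_code = cursor.description[0][:2]
            print(name, type_code)  # variant_col variant

            # The value itself still arrives as a JSON string.
            print(json.loads(cursor.fetchone()[0]))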

tests/unit/test_thrift_backend.py

Lines changed: 143 additions & 0 deletions
@@ -2200,6 +2200,149 @@ def test_execute_command_sets_complex_type_fields_correctly(
             t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow
         )

+    def test_col_to_description_with_variant_type(self):
+        # Test variant type detection from Arrow field metadata
+        col = ttypes.TColumnDesc(
+            columnName="variant_col",
+            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+        )
+
+        # Create a field with variant type in metadata
+        field = pyarrow.field(
+            "variant_col",
+            pyarrow.string(),
+            metadata={b'Spark:DataType:SqlName': b'VARIANT'}
+        )
+
+        result = ThriftBackend._col_to_description(col, field)
+
+        # Verify the result has variant as the type
+        self.assertEqual(result[0], "variant_col")  # Column name
+        self.assertEqual(result[1], "variant")  # Type name (should be variant instead of string)
+        self.assertIsNone(result[2])  # No display size
+        self.assertIsNone(result[3])  # No internal size
+        self.assertIsNone(result[4])  # No precision
+        self.assertIsNone(result[5])  # No scale
+        self.assertIsNone(result[6])  # No null ok
+
+    def test_col_to_description_without_variant_type(self):
+        # Test normal column without variant type
+        col = ttypes.TColumnDesc(
+            columnName="normal_col",
+            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+        )
+
+        # Create a normal field without variant metadata
+        field = pyarrow.field(
+            "normal_col",
+            pyarrow.string(),
+            metadata={}
+        )
+
+        result = ThriftBackend._col_to_description(col, field)
+
+        # Verify the result has string as the type (unchanged)
+        self.assertEqual(result[0], "normal_col")  # Column name
+        self.assertEqual(result[1], "string")  # Type name (should be string)
+        self.assertIsNone(result[2])  # No display size
+        self.assertIsNone(result[3])  # No internal size
+        self.assertIsNone(result[4])  # No precision
+        self.assertIsNone(result[5])  # No scale
+        self.assertIsNone(result[6])  # No null ok
+
+    def test_col_to_description_with_null_field(self):
+        # Test handling of null field
+        col = ttypes.TColumnDesc(
+            columnName="missing_field",
+            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+        )
+
+        # Pass None as the field
+        result = ThriftBackend._col_to_description(col, None)
+
+        # Verify the result has string as the type (unchanged)
+        self.assertEqual(result[0], "missing_field")  # Column name
+        self.assertEqual(result[1], "string")  # Type name (should be string)
+        self.assertIsNone(result[2])  # No display size
+        self.assertIsNone(result[3])  # No internal size
+        self.assertIsNone(result[4])  # No precision
+        self.assertIsNone(result[5])  # No scale
+        self.assertIsNone(result[6])  # No null ok
+
+    def test_hive_schema_to_description_with_arrow_schema(self):
+        # Create a table schema with regular and variant columns
+        columns = [
+            ttypes.TColumnDesc(
+                columnName="regular_col",
+                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+            ),
+            ttypes.TColumnDesc(
+                columnName="variant_col",
+                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+            ),
+        ]
+        t_table_schema = ttypes.TTableSchema(columns=columns)
+
+        # Create an Arrow schema with one variant column
+        fields = [
+            pyarrow.field("regular_col", pyarrow.string()),
+            pyarrow.field(
+                "variant_col",
+                pyarrow.string(),
+                metadata={b'Spark:DataType:SqlName': b'VARIANT'}
+            )
+        ]
+        arrow_schema = pyarrow.schema(fields)
+        schema_bytes = arrow_schema.serialize().to_pybytes()
+
+        # Get the description
+        description = ThriftBackend._hive_schema_to_description(t_table_schema, schema_bytes)
+
+        # Verify regular column type
+        self.assertEqual(description[0][0], "regular_col")
+        self.assertEqual(description[0][1], "string")
+
+        # Verify variant column type
+        self.assertEqual(description[1][0], "variant_col")
+        self.assertEqual(description[1][1], "variant")
+
+    def test_hive_schema_to_description_with_null_schema_bytes(self):
+        # Create a simple table schema
+        columns = [
+            ttypes.TColumnDesc(
+                columnName="regular_col",
+                typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+            ),
+        ]
+        t_table_schema = ttypes.TTableSchema(columns=columns)
+
+        # Get the description with null schema_bytes
+        description = ThriftBackend._hive_schema_to_description(t_table_schema, None)
+
+        # Verify column type remains unchanged
+        self.assertEqual(description[0][0], "regular_col")
+        self.assertEqual(description[0][1], "string")
+
+    def test_col_to_description_with_malformed_metadata(self):
+        # Test handling of malformed metadata
+        col = ttypes.TColumnDesc(
+            columnName="weird_field",
+            typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
+        )
+
+        # Create a field with malformed metadata
+        field = pyarrow.field(
+            "weird_field",
+            pyarrow.string(),
+            metadata={b'Spark:DataType:SqlName': b'Some unexpected value'}
+        )
+
+        result = ThriftBackend._col_to_description(col, field)
+
+        # Verify the type remains unchanged
+        self.assertEqual(result[0], "weird_field")  # Column name
+        self.assertEqual(result[1], "string")  # Type name (should remain string)
+

 if __name__ == "__main__":
     unittest.main()
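
Taken together, these unit tests pin down one rule: the Thrift-derived type name is rewritten to "variant" only on an exact metadata match, and every fallback (no field, empty metadata, an unexpected marker value) leaves it untouched. A condensed standalone sketch of that rule, assuming only pyarrow (detect_type is an illustrative helper, not part of the driver):

    import pyarrow

    def detect_type(thrift_type, field):
        # Rewrite only on an exact match; all fallbacks keep the original name.
        if field is not None and field.metadata:
            if field.metadata.get(b"Spark:DataType:SqlName") == b"VARIANT":
                return "variant"
        return thrift_type

    variant = pyarrow.field(
        "v", pyarrow.string(), metadata={b"Spark:DataType:SqlName": b"VARIANT"}
    )
    odd = pyarrow.field(
        "w", pyarrow.string(), metadata={b"Spark:DataType:SqlName": b"Some unexpected value"}
    )
    assert detect_type("string", variant) == "variant"
    assert detect_type("string", odd) == "string"
    assert detect_type("string", None) == "string"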
