Skip to content

Commit f50d9ab

Browse files
preliminary complex types support
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 59d28b0 commit f50d9ab

File tree

5 files changed

+110
-31
lines changed

5 files changed

+110
-31
lines changed

src/databricks/sql/backend/sea/result_set.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import json
34
from typing import Any, List, Optional, TYPE_CHECKING
45

56
import logging
@@ -82,6 +83,47 @@ def __init__(
8283
arrow_schema_bytes=execute_response.arrow_schema_bytes,
8384
)
8485

86+
assert isinstance(
87+
self.backend, SeaDatabricksClient
88+
), "SeaResultSet must be used with SeaDatabricksClient"
89+
90+
def _convert_complex_types_to_string(
    self, rows: "pyarrow.Table"
) -> "pyarrow.Table":
    """
    Convert complex types (array, struct, map) to string representation.

    Every list, large-list, struct or map column of ``rows`` is replaced by a
    string column holding the JSON encoding of each value. Null entries stay
    null and all other columns are passed through unchanged.

    Args:
        rows: Input PyArrow table

    Returns:
        PyArrow table with complex types converted to strings. When ``rows``
        has no complex column it is returned as-is.

    Raises:
        ImportError: If pyarrow is not available.
    """

    if not pyarrow:
        raise ImportError(
            "PyArrow is not installed: _use_arrow_native_complex_types = False requires pyarrow"
        )

    def is_complex_type(arrow_type: "pyarrow.DataType") -> bool:
        # The set of Arrow types that get rendered as strings.
        return (
            pyarrow.types.is_list(arrow_type)
            or pyarrow.types.is_large_list(arrow_type)
            or pyarrow.types.is_struct(arrow_type)
            or pyarrow.types.is_map(arrow_type)
        )

    def convert_complex_column_to_string(col: "pyarrow.Array") -> "pyarrow.Array":
        # default=str: nested values come back from to_pylist() as plain
        # Python objects, and some of them (Decimal, date/datetime, bytes)
        # are not JSON-native. Without a fallback json.dumps raises TypeError
        # and the whole fetch fails.
        # NOTE(review): map values are encoded exactly as to_pylist() yields
        # them (pyarrow's default is a list of key/value pairs) - confirm this
        # matches the string format expected from the other backend.
        json_strings = [
            None if val is None else json.dumps(val, default=str)
            for val in col.to_pylist()
        ]
        return pyarrow.array(json_strings, type=pyarrow.string())

    columns = rows.columns

    # Nothing to convert: hand the table back instead of rebuilding it.
    if not any(is_complex_type(col.type) for col in columns):
        return rows

    converted_columns = [
        convert_complex_column_to_string(col) if is_complex_type(col.type) else col
        for col in columns
    ]

    return pyarrow.Table.from_arrays(converted_columns, names=rows.column_names)
126+
85127
def _convert_json_types(self, row: List[str]) -> List[Any]:
86128
"""
87129
Convert string values in the row to appropriate Python types based on column metadata.
@@ -200,6 +242,9 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
200242
if isinstance(self.results, JsonQueue):
201243
results = self._convert_json_to_arrow_table(results)
202244

245+
if not self.backend._use_arrow_native_complex_types:
246+
results = self._convert_complex_types_to_string(results)
247+
203248
self._next_row_index += results.num_rows
204249

205250
return results
@@ -213,6 +258,9 @@ def fetchall_arrow(self) -> "pyarrow.Table":
213258
if isinstance(self.results, JsonQueue):
214259
results = self._convert_json_to_arrow_table(results)
215260

261+
if not self.backend._use_arrow_native_complex_types:
262+
results = self._convert_complex_types_to_string(results)
263+
216264
self._next_row_index += results.num_rows
217265

218266
return results

tests/e2e/test_complex_types.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,19 @@ def table_fixture(self, connection_details):
5454
("map_array_col", list),
5555
],
5656
)
57-
def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
57+
@pytest.mark.parametrize(
58+
"backend_params",
59+
[
60+
{},
61+
{"use_sea": True},
62+
],
63+
)
64+
def test_read_complex_types_as_arrow(
65+
self, field, expected_type, table_fixture, backend_params
66+
):
5867
"""Confirms the return types of a complex type field when reading as arrow"""
5968

60-
with self.cursor() as cursor:
69+
with self.cursor(extra_params=backend_params) as cursor:
6170
result = cursor.execute(
6271
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
6372
).fetchone()
@@ -75,11 +84,17 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
7584
("map_array_col"),
7685
],
7786
)
78-
def test_read_complex_types_as_string(self, field, table_fixture):
87+
@pytest.mark.parametrize(
88+
"backend_params",
89+
[
90+
{},
91+
{"use_sea": True},
92+
],
93+
)
94+
def test_read_complex_types_as_string(self, field, table_fixture, backend_params):
7995
"""Confirms the return type of a complex type that is returned as a string"""
80-
with self.cursor(
81-
extra_params={"_use_arrow_native_complex_types": False}
82-
) as cursor:
96+
extra_params = {**backend_params, "_use_arrow_native_complex_types": False}
97+
with self.cursor(extra_params=extra_params) as cursor:
8398
result = cursor.execute(
8499
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
85100
).fetchone()

tests/unit/test_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,7 @@ def test_negative_fetch_throws_exception(self):
262262
mock_backend = Mock()
263263
mock_backend.fetch_results.return_value = (Mock(), False, 0)
264264

265-
result_set = ThriftResultSet(
266-
Mock(), Mock(), mock_backend
267-
)
265+
result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
268266

269267
with self.assertRaises(ValueError) as e:
270268
result_set.fetchmany(-1)

tests/unit/test_downloader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ class DownloaderTests(unittest.TestCase):
2626
def _setup_time_mock_for_download(self, mock_time, end_time):
2727
"""Helper to setup time mock that handles logging system calls."""
2828
call_count = [0]
29+
2930
def time_side_effect():
3031
call_count[0] += 1
3132
if call_count[0] <= 2: # First two calls (validation, start_time)
3233
return 1000
3334
else: # All subsequent calls (logging, duration calculation)
3435
return end_time
36+
3537
mock_time.side_effect = time_side_effect
3638

3739
@patch("time.time", return_value=1000)
@@ -104,7 +106,7 @@ def test_run_get_response_not_ok(self, mock_time):
104106
@patch("time.time")
105107
def test_run_uncompressed_successful(self, mock_time):
106108
self._setup_time_mock_for_download(mock_time, 1000.5)
107-
109+
108110
http_client = DatabricksHttpClient.get_instance()
109111
file_bytes = b"1234567890" * 10
110112
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
@@ -133,7 +135,7 @@ def test_run_uncompressed_successful(self, mock_time):
133135
@patch("time.time")
134136
def test_run_compressed_successful(self, mock_time):
135137
self._setup_time_mock_for_download(mock_time, 1000.2)
136-
138+
137139
http_client = DatabricksHttpClient.get_instance()
138140
file_bytes = b"1234567890" * 10
139141
compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'

tests/unit/test_telemetry_retry.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
77
from databricks.sql.auth.retry import DatabricksRetryPolicy
88

9-
PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
9+
PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn"
10+
1011

1112
def create_mock_conn(responses):
1213
"""Creates a mock connection object whose getresponse() method yields a series of responses."""
@@ -16,15 +17,18 @@ def create_mock_conn(responses):
1617
mock_http_response = MagicMock()
1718
mock_http_response.status = resp.get("status")
1819
mock_http_response.headers = resp.get("headers", {})
19-
body = resp.get("body", b'{}')
20+
body = resp.get("body", b"{}")
2021
mock_http_response.fp = io.BytesIO(body)
22+
2123
def release():
2224
mock_http_response.fp.close()
25+
2326
mock_http_response.release_conn = release
2427
mock_http_responses.append(mock_http_response)
2528
mock_conn.getresponse.side_effect = mock_http_responses
2629
return mock_conn
2730

31+
2832
class TestTelemetryClientRetries:
2933
@pytest.fixture(autouse=True)
3034
def setup_and_teardown(self):
@@ -49,28 +53,28 @@ def get_client(self, session_id, num_retries=3):
4953
host_url="test.databricks.com",
5054
)
5155
client = TelemetryClientFactory.get_telemetry_client(session_id)
52-
56+
5357
retry_policy = DatabricksRetryPolicy(
5458
delay_min=0.01,
5559
delay_max=0.02,
5660
stop_after_attempts_duration=2.0,
57-
stop_after_attempts_count=num_retries,
61+
stop_after_attempts_count=num_retries,
5862
delay_default=0.1,
5963
force_dangerous_codes=[],
60-
urllib3_kwargs={'total': num_retries}
64+
urllib3_kwargs={"total": num_retries},
6165
)
6266
adapter = client._http_client.session.adapters.get("https://")
6367
adapter.max_retries = retry_policy
6468
return client
6569

6670
@pytest.mark.parametrize(
67-
"status_code, description",
68-
[
69-
(401, "Unauthorized"),
70-
(403, "Forbidden"),
71-
(501, "Not Implemented"),
72-
(200, "Success"),
73-
],
71+
"status_code, description",
72+
[
73+
(401, "Unauthorized"),
74+
(403, "Forbidden"),
75+
(501, "Not Implemented"),
76+
(200, "Success"),
77+
],
7478
)
7579
def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
7680
"""
@@ -80,7 +84,9 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti
8084
client = self.get_client(f"session-{status_code}")
8185
mock_responses = [{"status": status_code}]
8286

83-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
87+
with patch(
88+
PATCH_TARGET, return_value=create_mock_conn(mock_responses)
89+
) as mock_get_conn:
8490
client.export_failure_log("TestError", "Test message")
8591
TelemetryClientFactory.close(client._session_id_hex)
8692

@@ -92,16 +98,26 @@ def test_exceeds_retry_count_limit(self):
9298
Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
9399
"""
94100
num_retries = 3
95-
expected_total_calls = num_retries + 1
101+
expected_total_calls = num_retries + 1
96102
retry_after = 1
97103
client = self.get_client("session-exceed-limit", num_retries=num_retries)
98-
mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}]
99-
100-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
104+
mock_responses = [
105+
{"status": 503, "headers": {"Retry-After": str(retry_after)}},
106+
{"status": 429},
107+
{"status": 502},
108+
{"status": 503},
109+
]
110+
111+
with patch(
112+
PATCH_TARGET, return_value=create_mock_conn(mock_responses)
113+
) as mock_get_conn:
101114
start_time = time.time()
102115
client.export_failure_log("TestError", "Test message")
103116
TelemetryClientFactory.close(client._session_id_hex)
104117
end_time = time.time()
105-
106-
assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls
107-
assert end_time - start_time > retry_after
118+
119+
assert (
120+
mock_get_conn.return_value.getresponse.call_count
121+
== expected_total_calls
122+
)
123+
assert end_time - start_time > retry_after

0 commit comments

Comments
 (0)