Skip to content

Commit 2564d41

Browse files
committed
merge sea-migration fixes
Signed-off-by: Sai Shree Pradhan <[email protected]>
2 parents 79d58dd + c07beb1 commit 2564d41

File tree

17 files changed

+1327
-346
lines changed

17 files changed

+1327
-346
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
77

8-
from databricks.sql.backend.sea.models.base import ResultManifest
8+
from databricks.sql.backend.sea.models.base import ExternalLink, ResultManifest
99
from databricks.sql.backend.sea.utils.constants import (
1010
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
1111
ResultFormat,
@@ -28,7 +28,7 @@
2828
BackendType,
2929
ExecuteResponse,
3030
)
31-
from databricks.sql.exc import DatabaseError, ProgrammingError, ServerOperationError
31+
from databricks.sql.exc import DatabaseError, ServerOperationError
3232
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
3333
from databricks.sql.types import SSLOptions
3434

@@ -44,22 +44,23 @@
4444
GetStatementResponse,
4545
CreateSessionResponse,
4646
)
47+
from databricks.sql.backend.sea.models.responses import GetChunksResponse
4748

4849
logger = logging.getLogger(__name__)
4950

5051

5152
def _filter_session_configuration(
52-
session_configuration: Optional[Dict[str, str]]
53-
) -> Optional[Dict[str, str]]:
53+
session_configuration: Optional[Dict[str, Any]],
54+
) -> Dict[str, str]:
5455
if not session_configuration:
55-
return None
56+
return {}
5657

5758
filtered_session_configuration = {}
5859
ignored_configs: Set[str] = set()
5960

6061
for key, value in session_configuration.items():
6162
if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP:
62-
filtered_session_configuration[key.lower()] = value
63+
filtered_session_configuration[key.lower()] = str(value)
6364
else:
6465
ignored_configs.add(key)
6566

@@ -88,6 +89,7 @@ class SeaDatabricksClient(DatabricksClient):
8889
STATEMENT_PATH = BASE_PATH + "statements"
8990
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
9091
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
92+
CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
9193

9294
# SEA constants
9395
POLL_INTERVAL_SECONDS = 0.2
@@ -123,18 +125,22 @@ def __init__(
123125
)
124126

125127
self._max_download_threads = kwargs.get("max_download_threads", 10)
128+
self._ssl_options = ssl_options
129+
self._use_arrow_native_complex_types = kwargs.get(
130+
"_use_arrow_native_complex_types", True
131+
)
126132

127133
# Extract warehouse ID from http_path
128134
self.warehouse_id = self._extract_warehouse_id(http_path)
129135

130136
# Initialize HTTP client
131-
self.http_client = SeaHttpClient(
137+
self._http_client = SeaHttpClient(
132138
server_hostname=server_hostname,
133139
port=port,
134140
http_path=http_path,
135141
http_headers=http_headers,
136142
auth_provider=auth_provider,
137-
ssl_options=ssl_options,
143+
ssl_options=self._ssl_options,
138144
**kwargs,
139145
)
140146

@@ -173,7 +179,7 @@ def _extract_warehouse_id(self, http_path: str) -> str:
173179
f"Note: SEA only works for warehouses."
174180
)
175181
logger.error(error_message)
176-
raise ProgrammingError(error_message)
182+
raise ValueError(error_message)
177183

178184
@property
179185
def max_download_threads(self) -> int:
@@ -182,7 +188,7 @@ def max_download_threads(self) -> int:
182188

183189
def open_session(
184190
self,
185-
session_configuration: Optional[Dict[str, str]],
191+
session_configuration: Optional[Dict[str, Any]],
186192
catalog: Optional[str],
187193
schema: Optional[str],
188194
) -> SessionId:
@@ -220,7 +226,7 @@ def open_session(
220226
schema=schema,
221227
)
222228

223-
response = self.http_client._make_request(
229+
response = self._http_client._make_request(
224230
method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
225231
)
226232

@@ -245,7 +251,7 @@ def close_session(self, session_id: SessionId) -> None:
245251
session_id: The session identifier returned by open_session()
246252
247253
Raises:
248-
ProgrammingError: If the session ID is invalid
254+
ValueError: If the session ID is invalid
249255
OperationalError: If there's an error closing the session
250256
"""
251257

@@ -260,7 +266,7 @@ def close_session(self, session_id: SessionId) -> None:
260266
session_id=sea_session_id,
261267
)
262268

263-
self.http_client._make_request(
269+
self._http_client._make_request(
264270
method="DELETE",
265271
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
266272
data=request_data.to_dict(),
@@ -342,7 +348,7 @@ def _results_message_to_execute_response(
342348

343349
# Check for compression
344350
lz4_compressed = (
345-
response.manifest.result_compression == ResultCompression.LZ4_FRAME
351+
response.manifest.result_compression == ResultCompression.LZ4_FRAME.value
346352
)
347353

348354
execute_response = ExecuteResponse(
@@ -424,7 +430,7 @@ def execute_command(
424430
enforce_embedded_schema_correctness: Whether to enforce schema correctness
425431
426432
Returns:
427-
ResultSet: A SeaResultSet instance for the executed command
433+
SeaResultSet: A SeaResultSet instance for the executed command
428434
"""
429435

430436
if session_id.backend_type != BackendType.SEA:
@@ -471,7 +477,7 @@ def execute_command(
471477
result_compression=result_compression,
472478
)
473479

474-
response_data = self.http_client._make_request(
480+
response_data = self._http_client._make_request(
475481
method="POST", path=self.STATEMENT_PATH, data=request.to_dict()
476482
)
477483
response = ExecuteStatementResponse.from_dict(response_data)
@@ -505,7 +511,7 @@ def cancel_command(self, command_id: CommandId) -> None:
505511
command_id: Command identifier to cancel
506512
507513
Raises:
508-
ProgrammingError: If the command ID is invalid
514+
ValueError: If the command ID is invalid
509515
"""
510516

511517
if command_id.backend_type != BackendType.SEA:
@@ -516,7 +522,7 @@ def cancel_command(self, command_id: CommandId) -> None:
516522
raise ValueError("Not a valid SEA command ID")
517523

518524
request = CancelStatementRequest(statement_id=sea_statement_id)
519-
self.http_client._make_request(
525+
self._http_client._make_request(
520526
method="POST",
521527
path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id),
522528
data=request.to_dict(),
@@ -530,7 +536,7 @@ def close_command(self, command_id: CommandId) -> None:
530536
command_id: Command identifier to close
531537
532538
Raises:
533-
ProgrammingError: If the command ID is invalid
539+
ValueError: If the command ID is invalid
534540
"""
535541

536542
if command_id.backend_type != BackendType.SEA:
@@ -541,7 +547,7 @@ def close_command(self, command_id: CommandId) -> None:
541547
raise ValueError("Not a valid SEA command ID")
542548

543549
request = CloseStatementRequest(statement_id=sea_statement_id)
544-
self.http_client._make_request(
550+
self._http_client._make_request(
545551
method="DELETE",
546552
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
547553
data=request.to_dict(),
@@ -558,7 +564,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
558564
CommandState: The current state of the command
559565
560566
Raises:
561-
ProgrammingError: If the command ID is invalid
567+
ValueError: If the command ID is invalid
562568
"""
563569

564570
if command_id.backend_type != BackendType.SEA:
@@ -569,7 +575,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
569575
raise ValueError("Not a valid SEA command ID")
570576

571577
request = GetStatementRequest(statement_id=sea_statement_id)
572-
response_data = self.http_client._make_request(
578+
response_data = self._http_client._make_request(
573579
method="GET",
574580
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
575581
data=request.to_dict(),
@@ -609,7 +615,7 @@ def get_execution_result(
609615
request = GetStatementRequest(statement_id=sea_statement_id)
610616

611617
# Get the statement result
612-
response_data = self.http_client._make_request(
618+
response_data = self._http_client._make_request(
613619
method="GET",
614620
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
615621
data=request.to_dict(),
@@ -631,6 +637,35 @@ def get_execution_result(
631637
arraysize=cursor.arraysize,
632638
)
633639

640+
def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
    """
    Fetch the external link for a single result chunk of a statement.

    Sends a GET request to CHUNK_PATH_WITH_ID_AND_INDEX (formatted with the
    statement ID and chunk index), parses the payload as a GetChunksResponse,
    and returns the first external link whose ``chunk_index`` equals the
    requested index.

    Args:
        statement_id: The statement ID
        chunk_index: Index of the chunk whose link is requested

    Returns:
        ExternalLink: External link for the requested chunk

    Raises:
        ServerOperationError: If the response carries no external link whose
            chunk index matches ``chunk_index`` (including when the response
            has no external links at all)
    """

    response_data = self._http_client._make_request(
        method="GET",
        path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
    )
    response = GetChunksResponse.from_dict(response_data)

    # external_links is Optional on the response model; treat None as "no links".
    links = response.external_links or []

    # The response may contain several links, so select the one for the
    # requested chunk rather than assuming it is the first entry.
    link = next(
        (candidate for candidate in links if candidate.chunk_index == chunk_index),
        None,
    )

    # None is the "not found" sentinel supplied to next() above; test for it
    # explicitly instead of relying on the truthiness of the link object.
    if link is None:
        raise ServerOperationError(
            f"No link found for chunk index {chunk_index}",
            {
                "operation-id": statement_id,
                "diagnostic-info": None,
            },
        )

    return link
668+
634669
# == Metadata Operations ==
635670

636671
def get_catalogs(

src/databricks/sql/backend/sea/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ExecuteStatementResponse,
2828
GetStatementResponse,
2929
CreateSessionResponse,
30+
GetChunksResponse,
3031
)
3132

3233
__all__ = [
@@ -49,4 +50,5 @@
4950
"ExecuteStatementResponse",
5051
"GetStatementResponse",
5152
"CreateSessionResponse",
53+
"GetChunksResponse",
5254
]

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
These models define the structures used in SEA API responses.
55
"""
66

7-
from typing import Dict, Any
7+
from typing import Dict, Any, List, Optional
88
from dataclasses import dataclass
99

1010
from databricks.sql.backend.types import CommandState
@@ -154,3 +154,37 @@ class CreateSessionResponse:
154154
def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse":
155155
"""Create a CreateSessionResponse from a dictionary."""
156156
return cls(session_id=data.get("session_id", ""))
157+
158+
159+
@dataclass
class GetChunksResponse:
    """
    Parsed payload of a "get result chunk" request for a statement.

    Every field is optional and defaults to None; the values are copied from
    the object returned by ``_parse_result``.

    The response model can be found in the docs, here:
    https://docs.databricks.com/api/workspace/statementexecution/getstatementresultchunkn
    """

    # Inline row data, when present in the response.
    data: Optional[List[List[Any]]] = None
    # Links to externally stored chunk data, when present in the response.
    external_links: Optional[List[ExternalLink]] = None
    # Remaining fields are chunk metadata copied through unchanged.
    byte_count: Optional[int] = None
    chunk_index: Optional[int] = None
    next_chunk_index: Optional[int] = None
    next_chunk_internal_link: Optional[str] = None
    row_count: Optional[int] = None
    row_offset: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse":
        """Build a GetChunksResponse from the raw response dictionary."""
        # The raw payload is wrapped under a "result" key before parsing;
        # presumably _parse_result expects that shape - confirm against
        # its definition.
        parsed = _parse_result({"result": data})

        # Attributes copied one-to-one from the parsed result onto this model.
        copied_attributes = (
            "data",
            "external_links",
            "byte_count",
            "chunk_index",
            "next_chunk_index",
            "next_chunk_internal_link",
            "row_count",
            "row_offset",
        )
        return cls(**{name: getattr(parsed, name) for name in copied_attributes})

0 commit comments

Comments
 (0)