Skip to content
107 changes: 55 additions & 52 deletions src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
DeleteSessionRequest,
StatementParameter,
ExecuteStatementResponse,
GetStatementResponse,
CreateSessionResponse,
)

Expand Down Expand Up @@ -324,7 +323,7 @@ def _extract_description_from_manifest(
return columns

def _results_message_to_execute_response(
self, response: GetStatementResponse
self, response: ExecuteStatementResponse
) -> ExecuteResponse:
"""
Convert a SEA response to an ExecuteResponse and extract result data.
Expand Down Expand Up @@ -358,6 +357,28 @@ def _results_message_to_execute_response(

return execute_response

def _response_to_result_set(
self, response: ExecuteStatementResponse, cursor: Cursor
) -> SeaResultSet:
"""
Convert a SEA response to a SeaResultSet.
"""

# Create and return a SeaResultSet
from databricks.sql.backend.sea.result_set import SeaResultSet

execute_response = self._results_message_to_execute_response(response)

return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
)

def _check_command_not_in_failed_or_closed_state(
self, state: CommandState, command_id: CommandId
) -> None:
Expand All @@ -378,7 +399,7 @@ def _check_command_not_in_failed_or_closed_state(

def _wait_until_command_done(
self, response: ExecuteStatementResponse
) -> CommandState:
) -> ExecuteStatementResponse:
"""
Wait until a command is done.
"""
Expand All @@ -388,11 +409,12 @@ def _wait_until_command_done(

while state in [CommandState.PENDING, CommandState.RUNNING]:
time.sleep(self.POLL_INTERVAL_SECONDS)
state = self.get_query_state(command_id)
response = self._poll_query(command_id)
state = response.status.state

self._check_command_not_in_failed_or_closed_state(state, command_id)

return state
return response

def execute_command(
self,
Expand Down Expand Up @@ -494,8 +516,12 @@ def execute_command(
if async_op:
return None

self._wait_until_command_done(response)
return self.get_execution_result(command_id, cursor)
if response.status.state == CommandState.SUCCEEDED:
# if the response succeeded within the wait_timeout, return the results immediately
return self._response_to_result_set(response, cursor)

response = self._wait_until_command_done(response)
return self._response_to_result_set(response, cursor)

def cancel_command(self, command_id: CommandId) -> None:
"""
Expand Down Expand Up @@ -547,18 +573,9 @@ def close_command(self, command_id: CommandId) -> None:
data=request.to_dict(),
)

def get_query_state(self, command_id: CommandId) -> CommandState:
def _poll_query(self, command_id: CommandId) -> ExecuteStatementResponse:
"""
Get the state of a running query.

Args:
command_id: Command identifier

Returns:
CommandState: The current state of the command

Raises:
ProgrammingError: If the command ID is invalid
Poll for the current command info.
"""

if command_id.backend_type != BackendType.SEA:
Expand All @@ -574,9 +591,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)
response = ExecuteStatementResponse.from_dict(response_data)

# Parse the response
response = GetStatementResponse.from_dict(response_data)
return response

def get_query_state(self, command_id: CommandId) -> CommandState:
"""
Get the state of a running query.

Args:
command_id: Command identifier

Returns:
CommandState: The current state of the command

Raises:
ProgrammingError: If the command ID is invalid
"""

response = self._poll_query(command_id)
return response.status.state

def get_execution_result(
Expand All @@ -598,38 +631,8 @@ def get_execution_result(
ValueError: If the command ID is invalid
"""

if command_id.backend_type != BackendType.SEA:
raise ValueError("Not a valid SEA command ID")

sea_statement_id = command_id.to_sea_statement_id()
if sea_statement_id is None:
raise ValueError("Not a valid SEA command ID")

# Create the request model
request = GetStatementRequest(statement_id=sea_statement_id)

# Get the statement result
response_data = self.http_client._make_request(
method="GET",
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)
response = GetStatementResponse.from_dict(response_data)

# Create and return a SeaResultSet
from databricks.sql.backend.sea.result_set import SeaResultSet

execute_response = self._results_message_to_execute_response(response)

return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
)
response = self._poll_query(command_id)
return self._response_to_result_set(response, cursor)

# == Metadata Operations ==

Expand Down
2 changes: 0 additions & 2 deletions src/databricks/sql/backend/sea/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from databricks.sql.backend.sea.models.responses import (
ExecuteStatementResponse,
GetStatementResponse,
CreateSessionResponse,
)

Expand All @@ -47,6 +46,5 @@
"DeleteSessionRequest",
# Response models
"ExecuteStatementResponse",
"GetStatementResponse",
"CreateSessionResponse",
]
20 changes: 0 additions & 20 deletions src/databricks/sql/backend/sea/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,26 +124,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
)


@dataclass
class GetStatementResponse:
"""Representation of the response from getting information about a statement."""

statement_id: str
status: StatementStatus
manifest: ResultManifest
result: ResultData

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
"""Create a GetStatementResponse from a dictionary."""
return cls(
statement_id=data.get("statement_id", ""),
status=_parse_status(data),
manifest=_parse_manifest(data),
result=_parse_result(data),
)


@dataclass
class CreateSessionResponse:
"""Representation of the response from creating a new session."""
Expand Down
9 changes: 3 additions & 6 deletions tests/unit/test_sea_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_command_execution_sync(
mock_http_client._make_request.return_value = execute_response

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
result = sea_client.execute_command(
operation="SELECT 1",
Expand All @@ -242,9 +242,6 @@ def test_command_execution_sync(
enforce_embedded_schema_correctness=False,
)
assert result == "mock_result_set"
cmd_id_arg = mock_get_result.call_args[0][0]
assert isinstance(cmd_id_arg, CommandId)
assert cmd_id_arg.guid == "test-statement-123"

# Test with invalid session ID
with pytest.raises(ValueError) as excinfo:
Expand Down Expand Up @@ -332,7 +329,7 @@ def test_command_execution_advanced(
mock_http_client._make_request.side_effect = [initial_response, poll_response]

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
with patch("time.sleep"):
result = sea_client.execute_command(
Expand Down Expand Up @@ -360,7 +357,7 @@ def test_command_execution_advanced(
dbsql_param = IntegerParameter(name="param1", value=1)
param = dbsql_param.as_tspark_param(named=True)

with patch.object(sea_client, "get_execution_result"):
with patch.object(sea_client, "_response_to_result_set"):
sea_client.execute_command(
operation="SELECT * FROM table WHERE col = :param1",
session_id=sea_session_id,
Expand Down
Loading