Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 57 additions & 8 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TExecuteStatementResp,
TOperationHandle,
THandleIdentifier,
TOperationState,
TOperationType,
)
from databricks.sql.thrift_backend import ThriftBackend
Expand All @@ -23,6 +24,7 @@
from databricks.sql.exc import RequestError, CursorAlreadyClosedError
from databricks.sql.types import Row

from databricks.sql.utils import ExecuteResponse
from tests.unit.test_fetches import FetchTests
from tests.unit.test_thrift_backend import ThriftBackendTestSuite
from tests.unit.test_arrow_queue import ArrowQueueSuite
Expand Down Expand Up @@ -168,22 +170,69 @@ def test_useragent_header(self, mock_client_class):
http_headers = mock_client_class.call_args[0][3]
self.assertIn(user_agent_header_with_entry, http_headers)

@patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new())
@patch("%s.client.ResultSet" % PACKAGE_NAME)
def test_closing_connection_closes_commands(self, mock_result_set_class):
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
def test_closing_connection_closes_commands(self, mock_thrift_client_class):
# Test once with has_been_closed_server side, once without
for closed in (True, False):
with self.subTest(closed=closed):
mock_result_set_class.return_value = Mock()
initial_state = (
TOperationState.FINISHED_STATE
if not closed
else TOperationState.CLOSED_STATE
)

# Mock the execute response with controlled state
mock_execute_response = Mock(spec=ExecuteResponse)
mock_execute_response.status = initial_state
mock_execute_response.has_been_closed_server_side = closed
mock_execute_response.is_staging_operation = False

# Mock the backend that will be used
mock_backend = Mock(spec=ThriftBackend)
mock_thrift_client_class.return_value = mock_backend

# Create connection and cursor
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
cursor = connection.cursor()
cursor.execute("SELECT 1;")

# Verify initial state
self.assertEqual(
mock_execute_response.has_been_closed_server_side, closed
)
self.assertEqual(mock_execute_response.status, initial_state)

# Mock execute_command to return our execute response
cursor.thrift_backend.execute_command = Mock(
return_value=mock_execute_response
)
cursor.execute("SELECT 1")

# Verify that cursor.execute() set up the result set correctly
active_result_set = cursor.active_result_set
self.assertEqual(active_result_set.has_been_closed_server_side, closed)

# Close the connection
connection.close()

self.assertTrue(
mock_result_set_class.return_value.has_been_closed_server_side
# Verify the close logic worked:
# 1. has_been_closed_server_side should always be True after close()
self.assertTrue(active_result_set.has_been_closed_server_side)

# 2. op_state should always be CLOSED after close()
self.assertEqual(
active_result_set.op_state,
connection.thrift_backend.CLOSED_OP_STATE,
)
mock_result_set_class.return_value.close.assert_called_once_with()

# 3. Backend close_command should be called appropriately
if not closed:
# Should have called backend.close_command during the close chain
mock_backend.close_command.assert_called_once_with(
mock_execute_response.command_handle
)
else:
# Should NOT have called backend.close_command (already closed)
mock_backend.close_command.assert_not_called()

@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
def test_cant_open_cursor_on_closed_connection(self, mock_client_class):
Expand Down
Loading