Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,16 @@ def __enter__(self) -> "Cursor":
return self

def __exit__(self, exc_type, exc_value, traceback):
self.close()
try:
logger.debug("Cursor context manager exiting, calling close()")
self.close()
except Exception as e:
logger.warning(f"Exception during cursor close in __exit__: {e}")
# Don't suppress the original exception if there was one
if exc_type is None:
# Only raise our new exception if there wasn't already one in progress
raise
return False

def __iter__(self):
if self.active_result_set:
Expand Down Expand Up @@ -1163,7 +1172,21 @@ def cancel(self) -> None:
def close(self) -> None:
"""Close cursor"""
self.open = False
self.active_op_handle = None

# Close active operation handle if it exists
if self.active_op_handle:
try:
self.thrift_backend.close_command(self.active_op_handle)
except RequestError as e:
if isinstance(e.args[1], CursorAlreadyClosedError):
logger.info("Operation was canceled by a prior request")
else:
logging.warning(f"Error closing operation handle: {e}")
except Exception as e:
logging.warning(f"Error closing operation handle: {e}")
finally:
self.active_op_handle = None

if self.active_result_set:
self._close_and_clear_active_result_set()

Expand Down
111 changes: 111 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import databricks.sql
import databricks.sql.client as client
from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError
from databricks.sql.exc import RequestError, CursorAlreadyClosedError
from databricks.sql.types import Row

from tests.unit.test_fetches import FetchTests
Expand Down Expand Up @@ -676,6 +677,116 @@ def test_access_current_query_id(self):
cursor.close()
self.assertIsNone(cursor.query_id)

def test_cursor_close_handles_exception(self):
"""Test that Cursor.close() handles exceptions from close_command properly."""
mock_backend = Mock()
mock_connection = Mock()
mock_op_handle = Mock()

mock_backend.close_command.side_effect = Exception("Test error")

cursor = client.Cursor(mock_connection, mock_backend)
cursor.active_op_handle = mock_op_handle

cursor.close()

mock_backend.close_command.assert_called_once_with(mock_op_handle)

self.assertIsNone(cursor.active_op_handle)

self.assertFalse(cursor.open)

def test_cursor_context_manager_handles_exit_exception(self):
"""Test that cursor's context manager handles exceptions during __exit__."""
mock_backend = Mock()
mock_connection = Mock()

cursor = client.Cursor(mock_connection, mock_backend)
original_close = cursor.close
cursor.close = Mock(side_effect=Exception("Test error during close"))

try:
with cursor:
raise ValueError("Test error inside context")
except ValueError:
pass

cursor.close.assert_called_once()

def test_connection_close_handles_cursor_close_exception(self):
"""Test that _close handles exceptions from cursor.close() properly."""
cursors_closed = []

def mock_close_with_exception():
cursors_closed.append(1)
raise Exception("Test error during close")

cursor1 = Mock()
cursor1.close = mock_close_with_exception

def mock_close_normal():
cursors_closed.append(2)

cursor2 = Mock()
cursor2.close = mock_close_normal

mock_backend = Mock()
mock_session_handle = Mock()

try:
for cursor in [cursor1, cursor2]:
try:
cursor.close()
except Exception:
pass

mock_backend.close_session(mock_session_handle)
except Exception as e:
self.fail(f"Connection close should handle exceptions: {e}")

self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called")

def test_resultset_close_handles_cursor_already_closed_error(self):
"""Test that ResultSet.close() handles CursorAlreadyClosedError properly."""
result_set = client.ResultSet.__new__(client.ResultSet)
result_set.thrift_backend = Mock()
result_set.thrift_backend.CLOSED_OP_STATE = 'CLOSED'
result_set.connection = Mock()
result_set.connection.open = True
result_set.op_state = 'RUNNING'
result_set.has_been_closed_server_side = False
result_set.command_id = Mock()

class MockRequestError(Exception):
def __init__(self):
self.args = ["Error message", CursorAlreadyClosedError()]

result_set.thrift_backend.close_command.side_effect = MockRequestError()

original_close = client.ResultSet.close
try:
try:
if (
result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE
and not result_set.has_been_closed_server_side
and result_set.connection.open
):
result_set.thrift_backend.close_command(result_set.command_id)
except MockRequestError as e:
if isinstance(e.args[1], CursorAlreadyClosedError):
pass
finally:
result_set.has_been_closed_server_side = True
result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE

result_set.thrift_backend.close_command.assert_called_once_with(result_set.command_id)

assert result_set.has_been_closed_server_side is True

assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE
finally:
pass


if __name__ == "__main__":
suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
Expand Down
Loading