Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,7 @@ def close(self) -> None:
if isinstance(e.args[1], CursorAlreadyClosedError):
logger.info("Operation was canceled by a prior request")
finally:
self.results.close()
self.has_been_closed_server_side = True
self.op_state = self.thrift_backend.CLOSED_OP_STATE

Expand Down
13 changes: 13 additions & 0 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def next_n_rows(self, num_rows: int):
def remaining_rows(self):
pass

@abstractmethod
def close(self):
pass


class ResultSetQueueFactory(ABC):
@staticmethod
Expand Down Expand Up @@ -157,6 +161,9 @@ def remaining_rows(self):
self.cur_row_index += slice.num_rows
return slice

def close(self):
return


class ArrowQueue(ResultSetQueue):
def __init__(
Expand Down Expand Up @@ -192,6 +199,9 @@ def remaining_rows(self) -> "pyarrow.Table":
self.cur_row_index += slice.num_rows
return slice

def close(self):
return


class CloudFetchQueue(ResultSetQueue):
def __init__(
Expand Down Expand Up @@ -341,6 +351,9 @@ def _create_empty_table(self) -> "pyarrow.Table":
# Create a 0-row table with just the schema bytes
return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)

def close(self):
self.download_manager._shutdown_manager()


ExecuteResponse = namedtuple(
"ExecuteResponse",
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,33 +267,39 @@ def test_arraysize_buffer_size_passthrough(
def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
mock_connection = Mock()
mock_backend = Mock()
mock_results = Mock()
result_set = client.ResultSet(
connection=mock_connection,
thrift_backend=mock_backend,
execute_response=Mock(),
)
result_set.results = mock_results
mock_connection.open = False

result_set.close()

self.assertFalse(mock_backend.close_command.called)
self.assertTrue(result_set.has_been_closed_server_side)
mock_results.close.assert_called_once()

def test_closing_result_set_hard_closes_commands(self):
mock_results_response = Mock()
mock_results_response.has_been_closed_server_side = False
mock_connection = Mock()
mock_thrift_backend = Mock()
mock_results = Mock()
mock_connection.open = True
result_set = client.ResultSet(
mock_connection, mock_results_response, mock_thrift_backend
)
result_set.results = mock_results

result_set.close()

mock_thrift_backend.close_command.assert_called_once_with(
mock_results_response.command_handle
)
mock_results.close.assert_called_once()

@patch("%s.client.ResultSet" % PACKAGE_NAME)
def test_executing_multiple_commands_uses_the_most_recent_command(
Expand Down
Loading