diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py
index 95ce39cbe..6424fc1bb 100644
--- a/databricks/sdk/_base_client.py
+++ b/databricks/sdk/_base_client.py
@@ -50,7 +50,8 @@ def __init__(self,
                  http_timeout_seconds: float = None,
                  extra_error_customizers: List[_ErrorCustomizer] = None,
                  debug_headers: bool = False,
-                 clock: Clock = None):
+                 clock: Clock = None,
+                 streaming_buffer_size: int = 1024 * 1024):  # 1MB
         """
         :param debug_truncate_bytes:
         :param retry_timeout_seconds:
@@ -68,6 +69,7 @@
         :param extra_error_customizers:
         :param debug_headers: Whether to include debug headers in the request log.
         :param clock: Clock object to use for time-related operations.
+        :param streaming_buffer_size: The size of the buffer to use for streaming responses.
         """

         self._debug_truncate_bytes = debug_truncate_bytes or 96
@@ -78,6 +80,7 @@
         self._clock = clock or RealClock()
         self._session = requests.Session()
         self._session.auth = self._authenticate
+        self._streaming_buffer_size = streaming_buffer_size

         # We don't use `max_retries` from HTTPAdapter to align with a more production-ready
         # retry strategy established in the Databricks SDK for Go. See _is_retryable and
@@ -158,7 +161,9 @@ def do(self,
         for header in response_headers if response_headers else []:
             resp[header] = response.headers.get(Casing.to_header_case(header))
         if raw:
-            resp["contents"] = _StreamingResponse(response)
+            streaming_response = _StreamingResponse(response)
+            streaming_response.set_chunk_size(self._streaming_buffer_size)
+            resp["contents"] = streaming_response
             return resp
         if not len(response.content):
             return resp
@@ -283,6 +288,11 @@
     def isatty(self) -> bool:
         return False

     def read(self, n: int = -1) -> bytes:
+        """
+        Read up to n bytes from the response stream. If n is negative, read
+        until the end of the stream.
+        """
+        self._open()
         read_everything = n < 0
         remaining_bytes = n
diff --git a/tests/test_base_client.py b/tests/test_base_client.py
index e9e7324a9..b55f4e7f8 100644
--- a/tests/test_base_client.py
+++ b/tests/test_base_client.py
@@ -1,5 +1,7 @@
+import random
 from http.server import BaseHTTPRequestHandler
 from typing import Iterator, List
+from unittest.mock import Mock

 import pytest
 import requests
@@ -276,3 +278,39 @@ def inner(h: BaseHTTPRequestHandler):

     assert 'foo' in res
     assert len(requests) == 2
+
+
+@pytest.mark.parametrize('chunk_size,expected_chunks,data_size',
+                         [(5, 20, 100),  # 100 / 5 bytes per chunk = 20 chunks
+                          (10, 10, 100),  # 100 / 10 bytes per chunk = 10 chunks
+                          (200, 1, 100),  # 100 / 200 bytes per chunk = 1 chunk
+                          ])
+def test_streaming_response_chunk_size(chunk_size, expected_chunks, data_size):
+    rng = random.Random(42)
+    test_data = bytes(rng.getrandbits(8) for _ in range(data_size))
+
+    content_chunks = []
+    mock_response = Mock(spec=requests.Response)
+
+    def mock_iter_content(chunk_size):
+        # Simulate how requests would chunk the data.
+        for i in range(0, len(test_data), chunk_size):
+            chunk = test_data[i:i + chunk_size]
+            content_chunks.append(chunk)  # track chunks for verification
+            yield chunk
+
+    mock_response.iter_content = mock_iter_content
+    stream = _StreamingResponse(mock_response)
+    stream.set_chunk_size(chunk_size)
+
+    # Read all data one byte at a time.
+    received_data = b""
+    while True:
+        chunk = stream.read(1)
+        if not chunk:
+            break
+        received_data += chunk
+
+    assert received_data == test_data  # all data was received correctly
+    assert len(content_chunks) == expected_chunks  # correct number of chunks
+    assert all(len(c) <= chunk_size for c in content_chunks)  # chunks don't exceed size