Skip to content
54 changes: 42 additions & 12 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import logging
import urllib.parse
from datetime import timedelta
Expand Down Expand Up @@ -130,6 +131,14 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
flattened = dict(flatten_dict(with_fixed_bools))
return flattened

@staticmethod
def _is_seekable_stream(data) -> bool:
    """Tell whether `data` is an `io.IOBase` stream that reports itself as seekable.

    Anything that is not an `io.IOBase` instance (including `None`, `str` and
    `bytes`) yields `False`; otherwise the result of `data.seekable()` is returned.
    """
    # `None` is not an IOBase instance, so the isinstance check also covers it.
    return isinstance(data, io.IOBase) and data.seekable()

def do(self,
method: str,
url: str,
Expand All @@ -144,18 +153,27 @@ def do(self,
if headers is None:
headers = {}
headers['User-Agent'] = self._user_agent_base
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)
response = retryable(self._perform)(method,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data,
auth=auth)

# Only retry if the request is not a stream or if the stream is seekable and
# we can rewind it. This is necessary to avoid bugs where the retry doesn't
# re-read already read data from the body.
if data is not None and not self._is_seekable_stream(data):
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
call = self._perform
else:
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)(self._perform)

response = call(method,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data,
auth=auth)

resp = dict()
for header in response_headers if response_headers else []:
Expand Down Expand Up @@ -226,6 +244,12 @@ def _perform(self,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
# Keep track of the initial position of the stream so that we can rewind it if
# we need to retry the request.
initial_data_position = 0
if self._is_seekable_stream(data):
initial_data_position = data.tell()

response = self._session.request(method,
url,
params=self._fix_query_string(query),
Expand All @@ -237,9 +261,15 @@ def _perform(self,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)

error = self._error_parser.get_api_error(response)
if error is not None:
# If the request body is a seekable stream, rewind it so that it is ready
# to be read again in case of a retry.
if self._is_seekable_stream(data):
data.seek(initial_data_position)
raise error from None

return response

def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:
Expand Down
138 changes: 138 additions & 0 deletions tests/test_base_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import random
from http.server import BaseHTTPRequestHandler
from typing import Iterator, List
Expand Down Expand Up @@ -314,3 +315,140 @@ def mock_iter_content(chunk_size):
assert received_data == test_data # all data was received correctly
assert len(content_chunks) == expected_chunks # correct number of chunks
assert all(len(c) <= chunk_size for c in content_chunks) # chunks don't exceed size


def test_is_seekable_stream():
    """Check `_BaseClient._is_seekable_stream` against non-streams and streams."""
    client = _BaseClient()

    # Values that are not streams at all must be rejected.
    for not_a_stream in (None, "string data", b"binary data", ["list", "data"], 42):
        assert not client._is_seekable_stream(not_a_stream)

    # A stream whose seekable() answers False must be rejected as well.
    unseekable = io.BytesIO(b"test data")
    unseekable.seekable = lambda: False
    assert not client._is_seekable_stream(unseekable)

    # In-memory binary and text streams are accepted.
    for in_memory in (io.BytesIO(b"test data"), io.StringIO("test data")):
        assert client._is_seekable_stream(in_memory)

    # So is a regular file object opened from disk.
    with open(__file__, 'rb') as handle:
        assert client._is_seekable_stream(handle)

    # Any io.IOBase subclass whose seekable() answers True is accepted.
    class CustomSeekableStream(io.IOBase):

        def seekable(self):
            return True

        def seek(self, offset, whence=0):
            return 0

        def tell(self):
            return 0

    assert client._is_seekable_stream(CustomSeekableStream())


def test_no_retry_on_non_seekable_stream():
    """`do` must send a non-seekable body exactly once and raise without retrying."""
    bodies_seen = []

    def always_throttle(h: BaseHTTPRequestHandler):
        # Record the request body (if any), then reply with a status that triggers a retry.
        body_size = int(h.headers.get('Content-Length', 0))
        if body_size > 0:
            bodies_seen.append(h.rfile.read(body_size))

        h.send_response(429)
        h.send_header('Retry-After', '1')
        h.end_headers()

    # Override seekable() on the instance so the stream appears non-seekable.
    payload = io.BytesIO(b"test data")
    payload.seekable = lambda: False

    with http_fixture_server(always_throttle) as host:
        client = _BaseClient()

        # The error should surface immediately instead of being retried.
        with pytest.raises(DatabricksError):
            client.do('POST', f'{host}/foo', data=payload)

        # Exactly one request reached the server, carrying the whole payload.
        assert bodies_seen == [b"test data"]


def test_perform_resets_seekable_stream_on_retry():
    """A failed `_perform` must rewind a seekable body to its offset from before the call."""
    bodies_seen = []

    def always_throttle(h: BaseHTTPRequestHandler):
        # Record the request body (if any), then reply with a status that triggers a retry.
        body_size = int(h.headers.get('Content-Length', 0))
        if body_size > 0:
            bodies_seen.append(h.rfile.read(body_size))

        h.send_response(429)
        h.send_header('Retry-After', '1')
        h.end_headers()

    # Consume the first 4 bytes so the starting offset is not 0; this tells apart
    # "rewound to the starting offset" from "rewound to the very beginning".
    payload = io.BytesIO(b"0123456789")
    payload.read(4)
    assert payload.tell() == 4

    attempts = 3
    with http_fixture_server(always_throttle) as host:
        client = _BaseClient()

        # Every attempt fails, and every failure should leave the stream rewound.
        for _ in range(attempts):
            with pytest.raises(DatabricksError):
                client._perform('POST', f'{host}/foo', data=payload)

        # Each attempt therefore sent the same remaining bytes.
        assert bodies_seen == [b"456789"] * attempts

        # The offset is back where it was before the first attempt.
        assert payload.tell() == 4


def test_perform_does_not_reset_nonseekable_stream_on_retry():
    """A failed `_perform` must leave a non-seekable body at the offset where sending stopped."""
    bodies_seen = []

    def always_throttle(h: BaseHTTPRequestHandler):
        # Record the request body (if any), then reply with a status that triggers a retry.
        body_size = int(h.headers.get('Content-Length', 0))
        if body_size > 0:
            bodies_seen.append(h.rfile.read(body_size))

        h.send_response(429)
        h.send_header('Retry-After', '1')
        h.end_headers()

    # Override seekable() on the instance so the stream appears non-seekable.
    payload = io.BytesIO(b"0123456789")
    payload.seekable = lambda: False

    # Consume the first 4 bytes so that an unwanted rewind would be observable:
    # the offset would end up at 4 instead of staying at the end of the stream.
    payload.read(4)
    assert payload.tell() == 4

    with http_fixture_server(always_throttle) as host:
        client = _BaseClient()

        # The call fails, and the stream must be left untouched afterwards.
        with pytest.raises(DatabricksError):
            client._perform('POST', f'{host}/foo', data=payload)

        # Only the bytes after the initial offset were sent.
        assert bodies_seen == [b"456789"]

        # The stream was NOT rewound: it sits at EOF (10 bytes in total).
        assert payload.tell() == 10
Loading