Skip to content
52 changes: 40 additions & 12 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import logging
import urllib.parse
from datetime import timedelta
Expand Down Expand Up @@ -130,6 +131,14 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
flattened = dict(flatten_dict(with_fixed_bools))
return flattened

@staticmethod
def _is_seekable_stream(data) -> bool:
    """Report whether ``data`` is an ``io.IOBase`` stream that supports seeking.

    Any value that is not an ``io.IOBase`` instance (``None`` included) yields
    ``False``; for an ``io.IOBase`` instance the result of its own
    ``seekable()`` method is returned.
    """
    # isinstance() already rejects None, so one short-circuit covers both guards.
    return isinstance(data, io.IOBase) and data.seekable()

def do(self,
method: str,
url: str,
Expand All @@ -144,18 +153,25 @@ def do(self,
if headers is None:
headers = {}
headers['User-Agent'] = self._user_agent_base
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)
response = retryable(self._perform)(method,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data,
auth=auth)

# Only retry when there is no request data, or when the data is a seekable
# stream that can be rewound. This avoids bugs where a retry would not re-send
# data that an earlier attempt has already read from the body.
call = self._perform
if data is None or self._is_seekable_stream(data):
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)(call)

response = call(method,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data,
auth=auth)

resp = dict()
for header in response_headers if response_headers else []:
Expand Down Expand Up @@ -226,6 +242,12 @@ def _perform(self,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
# Keep track of the initial position of the stream so that we can rewind it if
# we need to retry the request.
initial_data_position = 0
if self._is_seekable_stream(data):
initial_data_position = data.tell()

response = self._session.request(method,
url,
params=self._fix_query_string(query),
Expand All @@ -237,9 +259,15 @@ def _perform(self,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)

error = self._error_parser.get_api_error(response)
if error is not None:
# If the request body is a seekable stream, rewind it so that it is ready
# to be read again in case of a retry.
if self._is_seekable_stream(data):
data.seek(initial_data_position)
raise error from None

return response

def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:
Expand Down
96 changes: 96 additions & 0 deletions tests/test_base_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import random
from http.server import BaseHTTPRequestHandler
from typing import Iterator, List
Expand Down Expand Up @@ -314,3 +315,98 @@ def mock_iter_content(chunk_size):
assert received_data == test_data # all data was received correctly
assert len(content_chunks) == expected_chunks # correct number of chunks
assert all(len(c) <= chunk_size for c in content_chunks) # chunks don't exceed size


def test_no_retry_on_non_seekable_stream():
    """A request whose body is a non-seekable stream is sent once and never retried."""
    bodies_seen = []

    def always_throttle(h: BaseHTTPRequestHandler):
        # Capture the request body (if any), then reply with a 429 so that a
        # retry-capable call would attempt the request again.
        body_size = int(h.headers.get('Content-Length', 0))
        if body_size > 0:
            bodies_seen.append(h.rfile.read(body_size))

        h.send_response(429)
        h.send_header('Retry-After', '1')
        h.end_headers()

    # Override seekable() on the instance so the stream reports itself as
    # non-seekable.
    stream = io.BytesIO(b"test data")
    stream.seekable = lambda: False

    with http_fixture_server(always_throttle) as host:
        client = _BaseClient()

        # The error must surface right away instead of triggering a retry.
        with pytest.raises(DatabricksError):
            client.do('POST', f'{host}/foo', data=stream)

        # Exactly one request reached the server, carrying the full payload.
        assert bodies_seen == [b"test data"]


def test_perform_resets_seekable_stream_on_error():
    """After a failed ``_perform``, a seekable body is rewound to where it started."""
    bodies_seen = []

    def always_throttle(h: BaseHTTPRequestHandler):
        # Capture the request body (if any), then reply with a 429 error.
        body_size = int(h.headers.get('Content-Length', 0))
        if body_size > 0:
            bodies_seen.append(h.rfile.read(body_size))

        h.send_response(429)
        h.send_header('Retry-After', '1')
        h.end_headers()

    stream = io.BytesIO(b"0123456789")  # seekable by default

    with http_fixture_server(always_throttle) as host:
        client = _BaseClient()

        # Advance the stream before the call so the test can tell "rewound to
        # the starting offset" apart from "rewound to the very beginning".
        start_offset = 4
        stream.read(start_offset)
        assert stream.tell() == 4

        # The request fails, but the stream must be left ready for re-reading.
        with pytest.raises(DatabricksError):
            client._perform('POST', f'{host}/foo', data=stream)

        # Only the bytes after the starting offset were sent.
        assert bodies_seen == [b"456789"]

        # The stream is back at its pre-request position.
        assert stream.tell() == 4


def test_perform_does_not_reset_nonseekable_stream_on_error():
    """After a failed ``_perform``, a non-seekable body is left where the request left it."""
    bodies_seen = []

    def always_throttle(h: BaseHTTPRequestHandler):
        # Capture the request body (if any), then reply with a 429 error.
        body_size = int(h.headers.get('Content-Length', 0))
        if body_size > 0:
            bodies_seen.append(h.rfile.read(body_size))

        h.send_response(429)
        h.send_header('Retry-After', '1')
        h.end_headers()

    # Override seekable() on the instance so the stream reports itself as
    # non-seekable.
    stream = io.BytesIO(b"0123456789")
    stream.seekable = lambda: False

    with http_fixture_server(always_throttle) as host:
        client = _BaseClient()

        # Advance the stream before the call so that an unwanted rewind to the
        # starting offset would be distinguishable from no rewind at all.
        start_offset = 4
        stream.read(start_offset)
        assert stream.tell() == 4

        # The request fails and the stream must be left untouched afterwards.
        with pytest.raises(DatabricksError):
            client._perform('POST', f'{host}/foo', data=stream)

        # Only the bytes after the starting offset were sent.
        assert bodies_seen == [b"456789"]

        # No rewind happened: the stream is still at end-of-data.
        assert stream.tell() == 10  # EOF
Loading