Skip to content

Commit e8b7916

Browse files
[Fix] Rewind seekable streams before retrying (#821)
## What changes are proposed in this pull request? This PR adapts the retry mechanism of `BaseClient` to only retry if (i) the request is not a stream or (ii) the stream is seekable and can be reset to its initial position. This fixes a bug that led retries to ignore part of the request that were already processed in previous attempts. ## How is this tested? Added unit tests to verify that (i) non-seekable streams are not retried, and (ii) seekable streams are properly reset before retrying.
1 parent ee6e70a commit e8b7916

File tree

2 files changed

+178
-12
lines changed

2 files changed

+178
-12
lines changed

databricks/sdk/_base_client.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import logging
23
import urllib.parse
34
from datetime import timedelta
@@ -130,6 +131,14 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
130131
flattened = dict(flatten_dict(with_fixed_bools))
131132
return flattened
132133

134+
@staticmethod
135+
def _is_seekable_stream(data) -> bool:
136+
if data is None:
137+
return False
138+
if not isinstance(data, io.IOBase):
139+
return False
140+
return data.seekable()
141+
133142
def do(self,
134143
method: str,
135144
url: str,
@@ -144,18 +153,31 @@ def do(self,
144153
if headers is None:
145154
headers = {}
146155
headers['User-Agent'] = self._user_agent_base
147-
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
148-
is_retryable=self._is_retryable,
149-
clock=self._clock)
150-
response = retryable(self._perform)(method,
151-
url,
152-
query=query,
153-
headers=headers,
154-
body=body,
155-
raw=raw,
156-
files=files,
157-
data=data,
158-
auth=auth)
156+
157+
# Wrap strings and bytes in a seekable stream so that we can rewind them.
158+
if isinstance(data, (str, bytes)):
159+
data = io.BytesIO(data.encode('utf-8') if isinstance(data, str) else data)
160+
161+
# Only retry if the request is not a stream or if the stream is seekable and
162+
# we can rewind it. This is necessary to avoid bugs where the retry doesn't
163+
# re-read already read data from the body.
164+
if data is not None and not self._is_seekable_stream(data):
165+
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
166+
call = self._perform
167+
else:
168+
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
169+
is_retryable=self._is_retryable,
170+
clock=self._clock)(self._perform)
171+
172+
response = call(method,
173+
url,
174+
query=query,
175+
headers=headers,
176+
body=body,
177+
raw=raw,
178+
files=files,
179+
data=data,
180+
auth=auth)
159181

160182
resp = dict()
161183
for header in response_headers if response_headers else []:
@@ -226,6 +248,12 @@ def _perform(self,
226248
files=None,
227249
data=None,
228250
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
251+
# Keep track of the initial position of the stream so that we can rewind it if
252+
# we need to retry the request.
253+
initial_data_position = 0
254+
if self._is_seekable_stream(data):
255+
initial_data_position = data.tell()
256+
229257
response = self._session.request(method,
230258
url,
231259
params=self._fix_query_string(query),
@@ -237,9 +265,18 @@ def _perform(self,
237265
stream=raw,
238266
timeout=self._http_timeout_seconds)
239267
self._record_request_log(response, raw=raw or data is not None or files is not None)
268+
240269
error = self._error_parser.get_api_error(response)
241270
if error is not None:
271+
# If the request body is a seekable stream, rewind it so that it is ready
272+
# to be read again in case of a retry.
273+
#
274+
# TODO: This should be moved into a "before-retry" hook to avoid one
275+
# unnecessary seek on the last failed retry before aborting.
276+
if self._is_seekable_stream(data):
277+
data.seek(initial_data_position)
242278
raise error from None
279+
243280
return response
244281

245282
def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:

tests/test_base_client.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import random
23
from http.server import BaseHTTPRequestHandler
34
from typing import Iterator, List
@@ -316,3 +317,131 @@ def mock_iter_content(chunk_size):
316317
assert received_data == test_data # all data was received correctly
317318
assert len(content_chunks) == expected_chunks # correct number of chunks
318319
assert all(len(c) <= chunk_size for c in content_chunks) # chunks don't exceed size
320+
321+
322+
def test_is_seekable_stream():
323+
client = _BaseClient()
324+
325+
# Test various input types that are not streams.
326+
assert not client._is_seekable_stream(None) # None
327+
assert not client._is_seekable_stream("string data") # str
328+
assert not client._is_seekable_stream(b"binary data") # bytes
329+
assert not client._is_seekable_stream(["list", "data"]) # list
330+
assert not client._is_seekable_stream(42) # int
331+
332+
# Test non-seekable stream.
333+
non_seekable = io.BytesIO(b"test data")
334+
non_seekable.seekable = lambda: False
335+
assert not client._is_seekable_stream(non_seekable)
336+
337+
# Test seekable streams.
338+
assert client._is_seekable_stream(io.BytesIO(b"test data")) # BytesIO
339+
assert client._is_seekable_stream(io.StringIO("test data")) # StringIO
340+
341+
# Test file objects.
342+
with open(__file__, 'rb') as f:
343+
assert client._is_seekable_stream(f) # File object
344+
345+
# Test custom seekable stream.
346+
class CustomSeekableStream(io.IOBase):
347+
348+
def seekable(self):
349+
return True
350+
351+
def seek(self, offset, whence=0):
352+
return 0
353+
354+
def tell(self):
355+
return 0
356+
357+
assert client._is_seekable_stream(CustomSeekableStream())
358+
359+
360+
@pytest.mark.parametrize(
361+
'input_data',
362+
[
363+
b"0123456789", # bytes -> BytesIO
364+
"0123456789", # str -> BytesIO
365+
io.BytesIO(b"0123456789"), # BytesIO directly
366+
io.StringIO("0123456789"), # StringIO
367+
])
368+
def test_reset_seekable_stream_on_retry(input_data):
369+
received_data = []
370+
371+
# Retry two times before succeeding.
372+
def inner(h: BaseHTTPRequestHandler):
373+
if len(received_data) == 2:
374+
h.send_response(200)
375+
h.end_headers()
376+
else:
377+
h.send_response(429)
378+
h.end_headers()
379+
380+
content_length = int(h.headers.get('Content-Length', 0))
381+
if content_length > 0:
382+
received_data.append(h.rfile.read(content_length))
383+
384+
with http_fixture_server(inner) as host:
385+
client = _BaseClient()
386+
387+
# Retries should reset the stream.
388+
client.do('POST', f'{host}/foo', data=input_data)
389+
390+
assert received_data == [b"0123456789", b"0123456789", b"0123456789"]
391+
392+
393+
def test_reset_seekable_stream_to_their_initial_position_on_retry():
394+
received_data = []
395+
396+
# Retry two times before succeeding.
397+
def inner(h: BaseHTTPRequestHandler):
398+
if len(received_data) == 2:
399+
h.send_response(200)
400+
h.end_headers()
401+
else:
402+
h.send_response(429)
403+
h.end_headers()
404+
405+
content_length = int(h.headers.get('Content-Length', 0))
406+
if content_length > 0:
407+
received_data.append(h.rfile.read(content_length))
408+
409+
input_data = io.BytesIO(b"0123456789")
410+
input_data.seek(4)
411+
412+
with http_fixture_server(inner) as host:
413+
client = _BaseClient()
414+
415+
# Retries should reset the stream.
416+
client.do('POST', f'{host}/foo', data=input_data)
417+
418+
assert received_data == [b"456789", b"456789", b"456789"]
419+
assert input_data.tell() == 10 # EOF
420+
421+
422+
def test_no_retry_or_reset_on_non_seekable_stream():
423+
requests = []
424+
425+
# Always respond with a response that triggers a retry.
426+
def inner(h: BaseHTTPRequestHandler):
427+
content_length = int(h.headers.get('Content-Length', 0))
428+
if content_length > 0:
429+
requests.append(h.rfile.read(content_length))
430+
431+
h.send_response(429)
432+
h.send_header('Retry-After', '1')
433+
h.end_headers()
434+
435+
input_data = io.BytesIO(b"0123456789")
436+
input_data.seekable = lambda: False # makes the stream appear non-seekable
437+
438+
with http_fixture_server(inner) as host:
439+
client = _BaseClient()
440+
441+
# Should raise error immediately without retry.
442+
with pytest.raises(DatabricksError):
443+
client.do('POST', f'{host}/foo', data=input_data)
444+
445+
# Verify that only one request was made (no retries).
446+
assert requests == [b"0123456789"]
447+
assert input_data.tell() == 10 # EOF

0 commit comments

Comments
 (0)