Skip to content

Commit 787241f

Browse files
Rewind seekable stream before retrying
1 parent 271502b commit 787241f

File tree

2 files changed

+108
-12
lines changed

2 files changed

+108
-12
lines changed

databricks/sdk/_base_client.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import logging
23
import urllib.parse
34
from datetime import timedelta
@@ -130,6 +131,14 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
130131
flattened = dict(flatten_dict(with_fixed_bools))
131132
return flattened
132133

134+
@staticmethod
135+
def _is_seekable_stream(data) -> bool:
136+
if data is None:
137+
return False
138+
if not isinstance(data, io.IOBase):
139+
return False
140+
return data.seekable()
141+
133142
def do(self,
134143
method: str,
135144
url: str,
@@ -144,18 +153,25 @@ def do(self,
144153
if headers is None:
145154
headers = {}
146155
headers['User-Agent'] = self._user_agent_base
147-
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
148-
is_retryable=self._is_retryable,
149-
clock=self._clock)
150-
response = retryable(self._perform)(method,
151-
url,
152-
query=query,
153-
headers=headers,
154-
body=body,
155-
raw=raw,
156-
files=files,
157-
data=data,
158-
auth=auth)
156+
157+
# Only retry if the request is not a stream or if the stream is seekable and
158+
# we can rewind it. This is necessary to avoid bugs where the retry doesn't
159+
# re-read already read data from the body.
160+
call = self._perform
161+
if data is None or self._is_seekable_stream(data):
162+
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
163+
is_retryable=self._is_retryable,
164+
clock=self._clock)(call)
165+
166+
response = call(method,
167+
url,
168+
query=query,
169+
headers=headers,
170+
body=body,
171+
raw=raw,
172+
files=files,
173+
data=data,
174+
auth=auth)
159175

160176
resp = dict()
161177
for header in response_headers if response_headers else []:
@@ -226,6 +242,12 @@ def _perform(self,
226242
files=None,
227243
data=None,
228244
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
245+
# Keep track of the initial position of the stream so that we can rewind it if
246+
# we need to retry the request.
247+
initial_data_position = 0
248+
if self._is_seekable_stream(data):
249+
initial_data_position = data.tell()
250+
229251
response = self._session.request(method,
230252
url,
231253
params=self._fix_query_string(query),
@@ -237,9 +259,15 @@ def _perform(self,
237259
stream=raw,
238260
timeout=self._http_timeout_seconds)
239261
self._record_request_log(response, raw=raw or data is not None or files is not None)
262+
240263
error = self._error_parser.get_api_error(response)
241264
if error is not None:
265+
# If the request body is a seekable stream, rewind it so that it is ready
266+
# to be read again in case of a retry.
267+
if self._is_seekable_stream(data):
268+
data.seek(initial_data_position)
242269
raise error from None
270+
243271
return response
244272

245273
def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:

tests/test_base_client.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import random
23
from http.server import BaseHTTPRequestHandler
34
from typing import Iterator, List
@@ -314,3 +315,70 @@ def mock_iter_content(chunk_size):
314315
assert received_data == test_data # all data was received correctly
315316
assert len(content_chunks) == expected_chunks # correct number of chunks
316317
assert all(len(c) <= chunk_size for c in content_chunks) # chunks don't exceed size
318+
319+
320+
def test_perform_resets_seekable_stream_on_error():
321+
received_data = []
322+
323+
# Response that triggers a retry.
324+
def inner(h: BaseHTTPRequestHandler):
325+
content_length = int(h.headers.get('Content-Length', 0))
326+
if content_length > 0:
327+
received_data.append(h.rfile.read(content_length))
328+
329+
h.send_response(429)
330+
h.send_header('Retry-After', '1')
331+
h.end_headers()
332+
333+
stream = io.BytesIO(b"0123456789") # seekable stream
334+
335+
with http_fixture_server(inner) as host:
336+
client = _BaseClient()
337+
338+
# Read some data from the stream first to verify that the stream is
339+
# reset to the correct position rather than to its beginning.
340+
stream.read(4)
341+
assert stream.tell() == 4
342+
343+
# Call perform which should fail but reset the stream.
344+
with pytest.raises(DatabricksError):
345+
client._perform('POST', f'{host}/foo', data=stream)
346+
347+
assert received_data == [b"456789"]
348+
349+
# Verify stream was reset to initial position.
350+
assert stream.tell() == 4
351+
352+
353+
def test_perform_does_not_reset_nonseekable_stream_on_error():
354+
received_data = []
355+
356+
# Response that triggers a retry.
357+
def inner(h: BaseHTTPRequestHandler):
358+
content_length = int(h.headers.get('Content-Length', 0))
359+
if content_length > 0:
360+
received_data.append(h.rfile.read(content_length))
361+
362+
h.send_response(429)
363+
h.send_header('Retry-After', '1')
364+
h.end_headers()
365+
366+
stream = io.BytesIO(b"0123456789")
367+
stream.seekable = lambda: False # makes the stream appear non-seekable
368+
369+
with http_fixture_server(inner) as host:
370+
client = _BaseClient()
371+
372+
# Read some data from the stream first to verify that the stream is
373+
# reset to the correct position rather than to its beginning.
374+
stream.read(4)
375+
assert stream.tell() == 4
376+
377+
# Call perform which should fail but reset the stream.
378+
with pytest.raises(DatabricksError):
379+
client._perform('POST', f'{host}/foo', data=stream)
380+
381+
assert received_data == [b"456789"]
382+
383+
# Verify stream was NOT reset to initial position.
384+
assert stream.tell() == 10 # EOF

0 commit comments

Comments
 (0)