Skip to content

Commit f57c6a3

Browse files
Merge branch 'main' into renaud.hartert/main
2 parents 0009ec6 + 762c57b commit f57c6a3

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

databricks/sdk/_base_client.py

Lines changed: 20 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -159,16 +159,29 @@ def do(self,
159159
if isinstance(data, (str, bytes)):
160160
data = io.BytesIO(data.encode('utf-8') if isinstance(data, str) else data)
161161

162-
# Only retry if the request is not a stream or if the stream is seekable and
163-
# we can rewind it. This is necessary to avoid bugs where the retry doesn't
164-
# re-read already read data from the body.
165-
if data is not None and not self._is_seekable_stream(data):
166-
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
167-
call = self._perform
168-
else:
162+
if not data:
163+
# The request is not a stream.
169164
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
170165
is_retryable=self._is_retryable,
171166
clock=self._clock)(self._perform)
167+
elif self._is_seekable_stream(data):
168+
# Keep track of the initial position of the stream so that we can rewind to it
169+
# if we need to retry the request.
170+
initial_data_position = data.tell()
171+
172+
def rewind():
173+
logger.debug(f"Rewinding input data to offset {initial_data_position} before retry")
174+
data.seek(initial_data_position)
175+
176+
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
177+
is_retryable=self._is_retryable,
178+
clock=self._clock,
179+
before_retry=rewind)(self._perform)
180+
else:
181+
# Do not retry if the stream is not seekable. This is necessary to avoid bugs
182+
# where the retry doesn't re-read already read data from the stream.
183+
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
184+
call = self._perform
172185

173186
response = call(method,
174187
url,
@@ -249,12 +262,6 @@ def _perform(self,
249262
files=None,
250263
data=None,
251264
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
252-
# Keep track of the initial position of the stream so that we can rewind it if
253-
# we need to retry the request.
254-
initial_data_position = 0
255-
if self._is_seekable_stream(data):
256-
initial_data_position = data.tell()
257-
258265
response = self._session.request(method,
259266
url,
260267
params=self._fix_query_string(query),
@@ -266,16 +273,8 @@ def _perform(self,
266273
stream=raw,
267274
timeout=self._http_timeout_seconds)
268275
self._record_request_log(response, raw=raw or data is not None or files is not None)
269-
270276
error = self._error_parser.get_api_error(response)
271277
if error is not None:
272-
# If the request body is a seekable stream, rewind it so that it is ready
273-
# to be read again in case of a retry.
274-
#
275-
# TODO: This should be moved into a "before-retry" hook to avoid one
276-
# unnecessary seek on the last failed retry before aborting.
277-
if self._is_seekable_stream(data):
278-
data.seek(initial_data_position)
279278
raise error from None
280279

281280
return response

databricks/sdk/retries.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,8 @@ def retried(*,
1313
on: Sequence[Type[BaseException]] = None,
1414
is_retryable: Callable[[BaseException], Optional[str]] = None,
1515
timeout=timedelta(minutes=20),
16-
clock: Clock = None):
16+
clock: Clock = None,
17+
before_retry: Callable = None):
1718
has_allowlist = on is not None
1819
has_callback = is_retryable is not None
1920
if not (has_allowlist or has_callback) or (has_allowlist and has_callback):
@@ -54,6 +55,9 @@ def wrapper(*args, **kwargs):
5455
raise err
5556

5657
logger.debug(f'Retrying: {retry_reason} (sleeping ~{sleep}s)')
58+
if before_retry:
59+
before_retry()
60+
5761
clock.sleep(sleep + random())
5862
attempt += 1
5963
raise TimeoutError(f'Timed out after {timeout}') from last_err

0 commit comments

Comments
 (0)