Skip to content

Commit 762c57b

Browse files
authored
[Internal] Extract "before retry" handler, use it to rewind the stream (#878)
## What changes are proposed in this pull request? - Introduce a separate handler to be called before we retry the API call. This will make sure handler is called both when (1) we receive an error response we want to retry on and (2) when low-level connection exception is thrown. - Rewind the stream to the initial position in this handler (if applicable). ## How is this tested? Existing tests. **ALWAYS ANSWER THIS QUESTION:** Answer with "N/A" if tests are not applicable to your PR (e.g. if the PR only modifies comments). Do not be afraid of answering "Not tested" if the PR has not been tested. Being clear about what has been done and not done provides important context to the reviewers.
1 parent 4bcfb0a commit 762c57b

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

databricks/sdk/_base_client.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -159,16 +159,29 @@ def do(self,
159159
if isinstance(data, (str, bytes)):
160160
data = io.BytesIO(data.encode('utf-8') if isinstance(data, str) else data)
161161

162-
# Only retry if the request is not a stream or if the stream is seekable and
163-
# we can rewind it. This is necessary to avoid bugs where the retry doesn't
164-
# re-read already read data from the body.
165-
if data is not None and not self._is_seekable_stream(data):
166-
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
167-
call = self._perform
168-
else:
162+
if not data:
163+
# The request is not a stream.
169164
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
170165
is_retryable=self._is_retryable,
171166
clock=self._clock)(self._perform)
167+
elif self._is_seekable_stream(data):
168+
# Keep track of the initial position of the stream so that we can rewind to it
169+
# if we need to retry the request.
170+
initial_data_position = data.tell()
171+
172+
def rewind():
173+
logger.debug(f"Rewinding input data to offset {initial_data_position} before retry")
174+
data.seek(initial_data_position)
175+
176+
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
177+
is_retryable=self._is_retryable,
178+
clock=self._clock,
179+
before_retry=rewind)(self._perform)
180+
else:
181+
# Do not retry if the stream is not seekable. This is necessary to avoid bugs
182+
# where the retry doesn't re-read already read data from the stream.
183+
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
184+
call = self._perform
172185

173186
response = call(method,
174187
url,
@@ -249,12 +262,6 @@ def _perform(self,
249262
files=None,
250263
data=None,
251264
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
252-
# Keep track of the initial position of the stream so that we can rewind it if
253-
# we need to retry the request.
254-
initial_data_position = 0
255-
if self._is_seekable_stream(data):
256-
initial_data_position = data.tell()
257-
258265
response = self._session.request(method,
259266
url,
260267
params=self._fix_query_string(query),
@@ -266,16 +273,8 @@ def _perform(self,
266273
stream=raw,
267274
timeout=self._http_timeout_seconds)
268275
self._record_request_log(response, raw=raw or data is not None or files is not None)
269-
270276
error = self._error_parser.get_api_error(response)
271277
if error is not None:
272-
# If the request body is a seekable stream, rewind it so that it is ready
273-
# to be read again in case of a retry.
274-
#
275-
# TODO: This should be moved into a "before-retry" hook to avoid one
276-
# unnecessary seek on the last failed retry before aborting.
277-
if self._is_seekable_stream(data):
278-
data.seek(initial_data_position)
279278
raise error from None
280279

281280
return response

databricks/sdk/retries.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ def retried(*,
1313
on: Sequence[Type[BaseException]] = None,
1414
is_retryable: Callable[[BaseException], Optional[str]] = None,
1515
timeout=timedelta(minutes=20),
16-
clock: Clock = None):
16+
clock: Clock = None,
17+
before_retry: Callable = None):
1718
has_allowlist = on is not None
1819
has_callback = is_retryable is not None
1920
if not (has_allowlist or has_callback) or (has_allowlist and has_callback):
@@ -54,6 +55,9 @@ def wrapper(*args, **kwargs):
5455
raise err
5556

5657
logger.debug(f'Retrying: {retry_reason} (sleeping ~{sleep}s)')
58+
if before_retry:
59+
before_retry()
60+
5761
clock.sleep(sleep + random())
5862
attempt += 1
5963
raise TimeoutError(f'Timed out after {timeout}') from last_err

0 commit comments

Comments
 (0)