Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 20 additions & 21 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,29 @@ def do(self,
if isinstance(data, (str, bytes)):
data = io.BytesIO(data.encode('utf-8') if isinstance(data, str) else data)

# Only retry if the request is not a stream or if the stream is seekable and
# we can rewind it. This is necessary to avoid bugs where the retry doesn't
# re-read already read data from the body.
if data is not None and not self._is_seekable_stream(data):
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
call = self._perform
else:
if not data:
# The request is not a stream.
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)(self._perform)
elif self._is_seekable_stream(data):
# Keep track of the initial position of the stream so that we can rewind to it
# if we need to retry the request.
initial_data_position = data.tell()

def rewind():
logger.debug(f"Rewinding input data to offset {initial_data_position} before retry")
data.seek(initial_data_position)

call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock,
before_retry=rewind)(self._perform)
else:
# Do not retry if the stream is not seekable. This is necessary to avoid bugs
# where the retry doesn't re-read already read data from the stream.
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
call = self._perform

response = call(method,
url,
Expand Down Expand Up @@ -249,12 +262,6 @@ def _perform(self,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
# Keep track of the initial position of the stream so that we can rewind it if
# we need to retry the request.
initial_data_position = 0
if self._is_seekable_stream(data):
initial_data_position = data.tell()

response = self._session.request(method,
url,
params=self._fix_query_string(query),
Expand All @@ -266,16 +273,8 @@ def _perform(self,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)

error = self._error_parser.get_api_error(response)
if error is not None:
# If the request body is a seekable stream, rewind it so that it is ready
# to be read again in case of a retry.
#
# TODO: This should be moved into a "before-retry" hook to avoid one
# unnecessary seek on the last failed retry before aborting.
if self._is_seekable_stream(data):
data.seek(initial_data_position)
raise error from None

return response
Expand Down
6 changes: 5 additions & 1 deletion databricks/sdk/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def retried(*,
on: Sequence[Type[BaseException]] = None,
is_retryable: Callable[[BaseException], Optional[str]] = None,
timeout=timedelta(minutes=20),
clock: Clock = None):
clock: Clock = None,
before_retry: Callable = None):
has_allowlist = on is not None
has_callback = is_retryable is not None
if not (has_allowlist or has_callback) or (has_allowlist and has_callback):
Expand Down Expand Up @@ -54,6 +55,9 @@ def wrapper(*args, **kwargs):
raise err

logger.debug(f'Retrying: {retry_reason} (sleeping ~{sleep}s)')
if before_retry:
before_retry()

clock.sleep(sleep + random())
attempt += 1
raise TimeoutError(f'Timed out after {timeout}') from last_err
Expand Down
Loading