1+ import io
12import logging
23import urllib .parse
34from datetime import timedelta
@@ -130,6 +131,14 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
130131 flattened = dict (flatten_dict (with_fixed_bools ))
131132 return flattened
132133
134+ @staticmethod
135+ def _is_seekable_stream (data ) -> bool :
136+ if data is None :
137+ return False
138+ if not isinstance (data , io .IOBase ):
139+ return False
140+ return data .seekable ()
141+
133142 def do (self ,
134143 method : str ,
135144 url : str ,
@@ -144,18 +153,31 @@ def do(self,
144153 if headers is None :
145154 headers = {}
146155 headers ['User-Agent' ] = self ._user_agent_base
147- retryable = retried (timeout = timedelta (seconds = self ._retry_timeout_seconds ),
148- is_retryable = self ._is_retryable ,
149- clock = self ._clock )
150- response = retryable (self ._perform )(method ,
151- url ,
152- query = query ,
153- headers = headers ,
154- body = body ,
155- raw = raw ,
156- files = files ,
157- data = data ,
158- auth = auth )
156+
157+ # Wrap strings and bytes in a seekable stream so that we can rewind them.
158+ if isinstance (data , (str , bytes )):
159+ data = io .BytesIO (data .encode ('utf-8' ) if isinstance (data , str ) else data )
160+
161+ # Only retry if the request is not a stream or if the stream is seekable and
162+ # we can rewind it. This is necessary to avoid bugs where the retry doesn't
163+ # re-read already read data from the body.
164+ if data is not None and not self ._is_seekable_stream (data ):
165+ logger .debug (f"Retry disabled for non-seekable stream: type={ type (data )} " )
166+ call = self ._perform
167+ else :
168+ call = retried (timeout = timedelta (seconds = self ._retry_timeout_seconds ),
169+ is_retryable = self ._is_retryable ,
170+ clock = self ._clock )(self ._perform )
171+
172+ response = call (method ,
173+ url ,
174+ query = query ,
175+ headers = headers ,
176+ body = body ,
177+ raw = raw ,
178+ files = files ,
179+ data = data ,
180+ auth = auth )
159181
160182 resp = dict ()
161183 for header in response_headers if response_headers else []:
@@ -226,6 +248,12 @@ def _perform(self,
226248 files = None ,
227249 data = None ,
228250 auth : Callable [[requests .PreparedRequest ], requests .PreparedRequest ] = None ):
251+ # Keep track of the initial position of the stream so that we can rewind it if
252+ # we need to retry the request.
253+ initial_data_position = 0
254+ if self ._is_seekable_stream (data ):
255+ initial_data_position = data .tell ()
256+
229257 response = self ._session .request (method ,
230258 url ,
231259 params = self ._fix_query_string (query ),
@@ -237,9 +265,18 @@ def _perform(self,
237265 stream = raw ,
238266 timeout = self ._http_timeout_seconds )
239267 self ._record_request_log (response , raw = raw or data is not None or files is not None )
268+
240269 error = self ._error_parser .get_api_error (response )
241270 if error is not None :
271+ # If the request body is a seekable stream, rewind it so that it is ready
272+ # to be read again in case of a retry.
273+ #
274+ # TODO: This should be moved into a "before-retry" hook to avoid one
275+ # unnecessary seek on the last failed retry before aborting.
276+ if self ._is_seekable_stream (data ):
277+ data .seek (initial_data_position )
242278 raise error from None
279+
243280 return response
244281
245282 def _record_request_log (self , response : requests .Response , raw : bool = False ) -> None :
0 commit comments