1717from .logger import RoundTrip
1818from .retries import retried
1919
# Module-level logger shared by the HTTP client code in this file.
logger = logging.getLogger("databricks.sdk")
2121
2222
2323def _fix_host_if_needed (host : Optional [str ]) -> Optional [str ]:
2424 if not host :
2525 return host
2626
2727 # Add a default scheme if it's missing
28- if ' ://' not in host :
29- host = ' https://' + host
28+ if " ://" not in host :
29+ host = " https://" + host
3030
3131 o = urllib .parse .urlparse (host )
3232 # remove trailing slash
33- path = o .path .rstrip ('/' )
33+ path = o .path .rstrip ("/" )
3434 # remove port if 443
3535 netloc = o .netloc
3636 if o .port == 443 :
37- netloc = netloc .split (':' )[0 ]
37+ netloc = netloc .split (":" )[0 ]
3838
3939 return urllib .parse .urlunparse ((o .scheme , netloc , path , o .params , o .query , o .fragment ))
4040
4141
4242class _BaseClient :
4343
44- def __init__ (self ,
45- debug_truncate_bytes : int = None ,
46- retry_timeout_seconds : int = None ,
47- user_agent_base : str = None ,
48- header_factory : Callable [[], dict ] = None ,
49- max_connection_pools : int = None ,
50- max_connections_per_pool : int = None ,
51- pool_block : bool = True ,
52- http_timeout_seconds : float = None ,
53- extra_error_customizers : List [_ErrorCustomizer ] = None ,
54- debug_headers : bool = False ,
55- clock : Clock = None ,
56- streaming_buffer_size : int = 1024 * 1024 ): # 1MB
44+ def __init__ (
45+ self ,
46+ debug_truncate_bytes : int = None ,
47+ retry_timeout_seconds : int = None ,
48+ user_agent_base : str = None ,
49+ header_factory : Callable [[], dict ] = None ,
50+ max_connection_pools : int = None ,
51+ max_connections_per_pool : int = None ,
52+ pool_block : bool = True ,
53+ http_timeout_seconds : float = None ,
54+ extra_error_customizers : List [_ErrorCustomizer ] = None ,
55+ debug_headers : bool = False ,
56+ clock : Clock = None ,
57+ streaming_buffer_size : int = 1024 * 1024 ,
58+ ): # 1MB
5759 """
5860 :param debug_truncate_bytes:
5961 :param retry_timeout_seconds:
@@ -87,9 +89,11 @@ def __init__(self,
8789 # We don't use `max_retries` from HTTPAdapter to align with a more production-ready
8890 # retry strategy established in the Databricks SDK for Go. See _is_retryable and
8991 # @retried for more details.
90- http_adapter = requests .adapters .HTTPAdapter (pool_connections = max_connections_per_pool or 20 ,
91- pool_maxsize = max_connection_pools or 20 ,
92- pool_block = pool_block )
92+ http_adapter = requests .adapters .HTTPAdapter (
93+ pool_connections = max_connections_per_pool or 20 ,
94+ pool_maxsize = max_connection_pools or 20 ,
95+ pool_block = pool_block ,
96+ )
9397 self ._session .mount ("https://" , http_adapter )
9498
9599 # Default to 60 seconds
@@ -110,7 +114,7 @@ def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]:
110114 # See: https://github.com/databricks/databricks-sdk-py/issues/142
111115 if query is None :
112116 return None
113- with_fixed_bools = {k : v if type (v ) != bool else (' true' if v else ' false' ) for k , v in query .items ()}
117+ with_fixed_bools = {k : v if type (v ) != bool else (" true" if v else " false" ) for k , v in query .items ()}
114118
115119 # Query parameters may be nested, e.g.
116120 # {'filter_by': {'user_ids': [123, 456]}}
@@ -140,30 +144,34 @@ def _is_seekable_stream(data) -> bool:
140144 return False
141145 return data .seekable ()
142146
143- def do (self ,
144- method : str ,
145- url : str ,
146- query : dict = None ,
147- headers : dict = None ,
148- body : dict = None ,
149- raw : bool = False ,
150- files = None ,
151- data = None ,
152- auth : Callable [[requests .PreparedRequest ], requests .PreparedRequest ] = None ,
153- response_headers : List [str ] = None ) -> Union [dict , list , BinaryIO ]:
147+ def do (
148+ self ,
149+ method : str ,
150+ url : str ,
151+ query : dict = None ,
152+ headers : dict = None ,
153+ body : dict = None ,
154+ raw : bool = False ,
155+ files = None ,
156+ data = None ,
157+ auth : Callable [[requests .PreparedRequest ], requests .PreparedRequest ] = None ,
158+ response_headers : List [str ] = None ,
159+ ) -> Union [dict , list , BinaryIO ]:
154160 if headers is None :
155161 headers = {}
156- headers [' User-Agent' ] = self ._user_agent_base
162+ headers [" User-Agent" ] = self ._user_agent_base
157163
158164 # Wrap strings and bytes in a seekable stream so that we can rewind them.
159165 if isinstance (data , (str , bytes )):
160- data = io .BytesIO (data .encode (' utf-8' ) if isinstance (data , str ) else data )
166+ data = io .BytesIO (data .encode (" utf-8" ) if isinstance (data , str ) else data )
161167
162168 if not data :
163169 # The request is not a stream.
164- call = retried (timeout = timedelta (seconds = self ._retry_timeout_seconds ),
165- is_retryable = self ._is_retryable ,
166- clock = self ._clock )(self ._perform )
170+ call = retried (
171+ timeout = timedelta (seconds = self ._retry_timeout_seconds ),
172+ is_retryable = self ._is_retryable ,
173+ clock = self ._clock ,
174+ )(self ._perform )
167175 elif self ._is_seekable_stream (data ):
168176 # Keep track of the initial position of the stream so that we can rewind to it
169177 # if we need to retry the request.
@@ -173,25 +181,29 @@ def rewind():
173181 logger .debug (f"Rewinding input data to offset { initial_data_position } before retry" )
174182 data .seek (initial_data_position )
175183
176- call = retried (timeout = timedelta (seconds = self ._retry_timeout_seconds ),
177- is_retryable = self ._is_retryable ,
178- clock = self ._clock ,
179- before_retry = rewind )(self ._perform )
184+ call = retried (
185+ timeout = timedelta (seconds = self ._retry_timeout_seconds ),
186+ is_retryable = self ._is_retryable ,
187+ clock = self ._clock ,
188+ before_retry = rewind ,
189+ )(self ._perform )
180190 else :
181191 # Do not retry if the stream is not seekable. This is necessary to avoid bugs
182192 # where the retry doesn't re-read already read data from the stream.
183193 logger .debug (f"Retry disabled for non-seekable stream: type={ type (data )} " )
184194 call = self ._perform
185195
186- response = call (method ,
187- url ,
188- query = query ,
189- headers = headers ,
190- body = body ,
191- raw = raw ,
192- files = files ,
193- data = data ,
194- auth = auth )
196+ response = call (
197+ method ,
198+ url ,
199+ query = query ,
200+ headers = headers ,
201+ body = body ,
202+ raw = raw ,
203+ files = files ,
204+ data = data ,
205+ auth = auth ,
206+ )
195207
196208 resp = dict ()
197209 for header in response_headers if response_headers else []:
@@ -220,6 +232,7 @@ def _is_retryable(err: BaseException) -> Optional[str]:
220232 # and Databricks SDK for Go retries
221233 # (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go)
222234 from urllib3 .exceptions import ProxyError
235+
223236 if isinstance (err , ProxyError ):
224237 err = err .original_error
225238 if isinstance (err , requests .ConnectionError ):
@@ -230,48 +243,55 @@ def _is_retryable(err: BaseException) -> Optional[str]:
230243 #
231244 # return a simple string for debug log readability, as `raise TimeoutError(...) from err`
232245 # will bubble up the original exception in case we reach max retries.
233- return f' cannot connect'
246+ return f" cannot connect"
234247 if isinstance (err , requests .Timeout ):
235248 # corresponds to `TLS handshake timeout` and `i/o timeout` in Go.
236249 #
237250 # return a simple string for debug log readability, as `raise TimeoutError(...) from err`
238251 # will bubble up the original exception in case we reach max retries.
239- return f' timeout'
252+ return f" timeout"
240253 if isinstance (err , DatabricksError ):
241254 message = str (err )
242255 transient_error_string_matches = [
243256 "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException" ,
244- "does not have any associated worker environments" , "There is no worker environment with id" ,
245- "Unknown worker environment" , "ClusterNotReadyException" , "Unexpected error" ,
257+ "does not have any associated worker environments" ,
258+ "There is no worker environment with id" ,
259+ "Unknown worker environment" ,
260+ "ClusterNotReadyException" ,
261+ "Unexpected error" ,
246262 "Please try again later or try a faster operation." ,
247263 "RPC token bucket limit has been exceeded" ,
248264 ]
249265 for substring in transient_error_string_matches :
250266 if substring not in message :
251267 continue
252- return f' matched { substring } '
268+ return f" matched { substring } "
253269 return None
254270
255- def _perform (self ,
256- method : str ,
257- url : str ,
258- query : dict = None ,
259- headers : dict = None ,
260- body : dict = None ,
261- raw : bool = False ,
262- files = None ,
263- data = None ,
264- auth : Callable [[requests .PreparedRequest ], requests .PreparedRequest ] = None ):
265- response = self ._session .request (method ,
266- url ,
267- params = self ._fix_query_string (query ),
268- json = body ,
269- headers = headers ,
270- files = files ,
271- data = data ,
272- auth = auth ,
273- stream = raw ,
274- timeout = self ._http_timeout_seconds )
271+ def _perform (
272+ self ,
273+ method : str ,
274+ url : str ,
275+ query : dict = None ,
276+ headers : dict = None ,
277+ body : dict = None ,
278+ raw : bool = False ,
279+ files = None ,
280+ data = None ,
281+ auth : Callable [[requests .PreparedRequest ], requests .PreparedRequest ] = None ,
282+ ):
283+ response = self ._session .request (
284+ method ,
285+ url ,
286+ params = self ._fix_query_string (query ),
287+ json = body ,
288+ headers = headers ,
289+ files = files ,
290+ data = data ,
291+ auth = auth ,
292+ stream = raw ,
293+ timeout = self ._http_timeout_seconds ,
294+ )
275295 self ._record_request_log (response , raw = raw or data is not None or files is not None )
276296 error = self ._error_parser .get_api_error (response )
277297 if error is not None :
@@ -312,7 +332,7 @@ def flush(self) -> int:
312332
313333 def __init__ (self , response : _RawResponse , chunk_size : Union [int , None ] = None ):
314334 self ._response = response
315- self ._buffer = b''
335+ self ._buffer = b""
316336 self ._content = None
317337 self ._chunk_size = chunk_size
318338
@@ -338,14 +358,14 @@ def isatty(self) -> bool:
338358
339359 def read (self , n : int = - 1 ) -> bytes :
340360 """
341- Read up to n bytes from the response stream. If n is negative, read
342- until the end of the stream.
361+ Read up to n bytes from the response stream. If n is negative, read
362+ until the end of the stream.
343363 """
344364
345365 self ._open ()
346366 read_everything = n < 0
347367 remaining_bytes = n
348- res = b''
368+ res = b""
349369 while remaining_bytes > 0 or read_everything :
350370 if len (self ._buffer ) == 0 :
351371 try :
@@ -395,8 +415,12 @@ def __next__(self) -> bytes:
395415 def __iter__ (self ) -> Iterator [bytes ]:
396416 return self ._content
397417
398- def __exit__ (self , t : Union [Type [BaseException ], None ], value : Union [BaseException , None ],
399- traceback : Union [TracebackType , None ]) -> None :
418+ def __exit__ (
419+ self ,
420+ t : Union [Type [BaseException ], None ],
421+ value : Union [BaseException , None ],
422+ traceback : Union [TracebackType , None ],
423+ ) -> None :
400424 self ._content = None
401- self ._buffer = b''
425+ self ._buffer = b""
402426 self .close ()
0 commit comments