
Commit 230ed00

first work

1 parent 5e871cb commit 230ed00

File tree

10 files changed: +988 -721 lines changed

databricks/sdk/_base_client.py

Lines changed: 343 additions & 0 deletions
@@ -0,0 +1,343 @@
import logging
import urllib.parse
from datetime import timedelta
from types import TracebackType
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
                    Optional, Tuple, Type, Union)

import requests
import requests.adapters

from . import useragent
from .casing import Casing
from .clock import Clock, RealClock
from .errors import DatabricksError, _ErrorCustomizer, _Parser
from .logger import RoundTrip
from .retries import retried

logger = logging.getLogger('databricks.sdk')

def fix_host_if_needed(host: Optional[str]) -> Optional[str]:
    if not host:
        return host

    # Add a default scheme if it's missing.
    if '://' not in host:
        host = 'https://' + host

    o = urllib.parse.urlparse(host)
    # Remove any trailing slash.
    path = o.path.rstrip('/')
    # Drop an explicit port when it is the default HTTPS port 443.
    netloc = o.netloc
    if o.port == 443:
        netloc = netloc.split(':')[0]

    return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))

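# Illustrative normalization (hypothetical workspace URL, not from this commit):
#   fix_host_if_needed('my-workspace.cloud.databricks.com')
#     -> 'https://my-workspace.cloud.databricks.com'
#   fix_host_if_needed('https://my-workspace.cloud.databricks.com:443/')
#     -> 'https://my-workspace.cloud.databricks.com'
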
class _BaseClient:

    def __init__(self,
                 debug_truncate_bytes: Optional[int] = None,
                 retry_timeout_seconds: Optional[int] = None,
                 user_agent_base: Optional[str] = None,
                 header_factory: Optional[Callable[[], dict]] = None,
                 max_connection_pools: Optional[int] = None,
                 max_connections_per_pool: Optional[int] = None,
                 pool_block: bool = True,
                 http_timeout_seconds: Optional[float] = None,
                 extra_error_customizers: Optional[List[_ErrorCustomizer]] = None,
                 debug_headers: bool = False,
                 clock: Optional[Clock] = None):
        """
        :param debug_truncate_bytes: Maximum number of bytes of a request or response body to include
            in debug logs. Defaults to 96.
        :param retry_timeout_seconds: Total time budget, in seconds, for retrying a failed request.
            Defaults to 300.
        :param user_agent_base: Base User-Agent string to send with every request. Defaults to the
            SDK's standard user agent.
        :param header_factory: A function that returns a dictionary of headers to include in the request.
        :param max_connection_pools: Number of urllib3 connection pools to cache before discarding the least
            recently used pool. The Python requests default value is 10.
        :param max_connections_per_pool: The maximum number of connections to save in the pool. Improves
            performance in multithreaded situations. For now, it defaults to the same value as
            max_connection_pools.
        :param pool_block: If False, extra connections are created when the pool is exhausted, but they are
            not saved after first use. If True, requests block until a free connection is available; urllib3
            ensures that no more than pool_maxsize connections are used at a time, which prevents flooding
            the platform. By default, the requests library doesn't block.
        :param http_timeout_seconds: Timeout, in seconds, for a single HTTP request. Defaults to 60.
        :param extra_error_customizers: Additional error customizers applied when parsing API error responses.
        :param debug_headers: Whether to include debug headers in the request log.
        :param clock: Clock object to use for time-related operations.
        """

        self._debug_truncate_bytes = debug_truncate_bytes or 96
        self._debug_headers = debug_headers
        self._retry_timeout_seconds = retry_timeout_seconds or 300
        self._user_agent_base = user_agent_base or useragent.to_string()
        self._header_factory = header_factory
        self._clock = clock or RealClock()
        self._session = requests.Session()
        self._session.auth = self._authenticate

        # We don't use `max_retries` from HTTPAdapter, to align with the more production-ready
        # retry strategy established in the Databricks SDK for Go. See _is_retryable and
        # @retried for more details.
        http_adapter = requests.adapters.HTTPAdapter(pool_connections=max_connection_pools or 20,
                                                     pool_maxsize=max_connections_per_pool or 20,
                                                     pool_block=pool_block)
        self._session.mount("https://", http_adapter)

        # Default to 60 seconds.
        self._http_timeout_seconds = http_timeout_seconds or 60

        self._error_parser = _Parser(extra_error_customizers=extra_error_customizers)

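    # Illustrative construction (assumed values and endpoint, not from this commit):
    #   client = _BaseClient(retry_timeout_seconds=120, http_timeout_seconds=30)
    #   listing = client.do('GET', 'https://example.cloud.databricks.com/api/2.0/workspace/list',
    #                       query={'path': '/'})
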
    def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
        if self._header_factory:
            headers = self._header_factory()
            for k, v in headers.items():
                r.headers[k] = v
        return r

    @staticmethod
    def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]:
        # Convert True -> "true" so that Databricks APIs understand booleans.
        # See: https://github.com/databricks/databricks-sdk-py/issues/142
        if query is None:
            return None
        with_fixed_bools = {k: v if not isinstance(v, bool) else ('true' if v else 'false')
                            for k, v in query.items()}

        # Query parameters may be nested, e.g.
        #   {'filter_by': {'user_ids': [123, 456]}}
        # The HTTP-compatible representation of this is
        #   filter_by.user_ids=123&filter_by.user_ids=456
        # To achieve this, we convert the above dictionary to
        #   {'filter_by.user_ids': [123, 456]}
        # See the following for more information:
        # https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule
        def flatten_dict(d: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
            for k1, v1 in d.items():
                if isinstance(v1, dict):
                    v1 = dict(flatten_dict(v1))
                    for k2, v2 in v1.items():
                        yield f"{k1}.{k2}", v2
                else:
                    yield k1, v1

        flattened = dict(flatten_dict(with_fixed_bools))
        return flattened

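    # Illustrative behavior (hypothetical query, not from this commit):
    #   _fix_query_string({'expand_tasks': True, 'filter_by': {'user_ids': [123, 456]}})
    #     -> {'expand_tasks': 'true', 'filter_by.user_ids': [123, 456]}
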
    def do(self,
           method: str,
           url: str,
           query: Optional[dict] = None,
           headers: Optional[dict] = None,
           body: Optional[dict] = None,
           raw: bool = False,
           files=None,
           data=None,
           auth: Optional[Callable[[requests.PreparedRequest], requests.PreparedRequest]] = None,
           response_headers: Optional[List[str]] = None) -> Union[dict, list, BinaryIO]:
        if headers is None:
            headers = {}
        headers['User-Agent'] = self._user_agent_base
        retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
                            is_retryable=self._is_retryable,
                            clock=self._clock)
        response = retryable(self._perform)(method,
                                            url,
                                            query=query,
                                            headers=headers,
                                            body=body,
                                            raw=raw,
                                            files=files,
                                            data=data,
                                            auth=auth)

        resp = dict()
        for header in response_headers if response_headers else []:
            resp[header] = response.headers.get(Casing.to_header_case(header))
        if raw:
            resp["contents"] = _StreamingResponse(response)
            return resp
        if not len(response.content):
            return resp

        json_response = response.json()
        if json_response is None:
            return resp

        if isinstance(json_response, list):
            return json_response

        return {**resp, **json_response}

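    # Illustrative raw download (hypothetical endpoint and client, not from this commit):
    #   result = client.do('GET', 'https://example.cloud.databricks.com/api/2.0/fs/files/tmp/report.csv',
    #                      raw=True, response_headers=['content-type'])
    #   stream = result['contents']           # _StreamingResponse wrapping the HTTP body
    #   content_type = result['content-type']
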
    @staticmethod
    def _is_retryable(err: BaseException) -> Optional[str]:
        # This method is a Databricks-specific port of the urllib3 retries
        # (see https://github.com/urllib3/urllib3/blob/main/src/urllib3/util/retry.py)
        # and the Databricks SDK for Go retries
        # (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go).
        from urllib3.exceptions import ProxyError
        if isinstance(err, ProxyError):
            err = err.original_error
        if isinstance(err, requests.ConnectionError):
            # Corresponds to the `connection reset by peer` and `connection refused` errors from Go,
            # which are generally caused by temporary glitches in the networking stack. They are also
            # caused by endpoint protection software, like ZScaler, dropping connections that are not
            # yet authenticated.
            #
            # Return a simple string for debug log readability, as `raise TimeoutError(...) from err`
            # will bubble up the original exception in case we reach max retries.
            return 'cannot connect'
        if isinstance(err, requests.Timeout):
            # Corresponds to `TLS handshake timeout` and `i/o timeout` in Go.
            #
            # Return a simple string for debug log readability, as `raise TimeoutError(...) from err`
            # will bubble up the original exception in case we reach max retries.
            return 'timeout'
        if isinstance(err, DatabricksError):
            message = str(err)
            transient_error_string_matches = [
                "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException",
                "does not have any associated worker environments",
                "There is no worker environment with id",
                "Unknown worker environment",
                "ClusterNotReadyException",
                "Unexpected error",
                "Please try again later or try a faster operation.",
                "RPC token bucket limit has been exceeded",
            ]
            for substring in transient_error_string_matches:
                if substring not in message:
                    continue
                return f'matched {substring}'
        return None

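    # Illustrative classification (hypothetical errors, not from this commit):
    #   _is_retryable(requests.Timeout())         -> 'timeout'
    #   _is_retryable(requests.ConnectionError()) -> 'cannot connect'
    #   _is_retryable(ValueError('bad input'))    -> None
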
    def _perform(self,
                 method: str,
                 url: str,
                 query: Optional[dict] = None,
                 headers: Optional[dict] = None,
                 body: Optional[dict] = None,
                 raw: bool = False,
                 files=None,
                 data=None,
                 auth: Optional[Callable[[requests.PreparedRequest], requests.PreparedRequest]] = None):
        response = self._session.request(method,
                                         url,
                                         params=self._fix_query_string(query),
                                         json=body,
                                         headers=headers,
                                         files=files,
                                         data=data,
                                         auth=auth,
                                         stream=raw,
                                         timeout=self._http_timeout_seconds)
        self._record_request_log(response, raw=raw or data is not None or files is not None)
        error = self._error_parser.get_api_error(response)
        if error is not None:
            raise error from None
        return response

    def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate())


class _StreamingResponse(BinaryIO):
    _response: requests.Response
    _buffer: bytes
    _content: Union[Iterator[bytes], None]
    _chunk_size: Union[int, None]
    _closed: bool = False

    def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None):
        self._response = response
        self._buffer = b''
        self._content = None
        self._chunk_size = chunk_size

    def fileno(self) -> int:
        pass

    def flush(self) -> int:
        pass

    def _open(self) -> None:
        if self._closed:
            raise ValueError("I/O operation on closed file")
        if not self._content:
            self._content = self._response.iter_content(chunk_size=self._chunk_size)

    def __enter__(self) -> BinaryIO:
        self._open()
        return self

    def set_chunk_size(self, chunk_size: Union[int, None]) -> None:
        self._chunk_size = chunk_size

    def close(self) -> None:
        self._response.close()
        self._closed = True

    def isatty(self) -> bool:
        return False

    def read(self, n: int = -1) -> bytes:
        # Read up to n bytes from the underlying response, buffering any partially
        # consumed chunk across calls; n < 0 means read everything that remains.
        self._open()
        read_everything = n < 0
        remaining_bytes = n
        res = b''
        while remaining_bytes > 0 or read_everything:
            if len(self._buffer) == 0:
                try:
                    self._buffer = next(self._content)
                except StopIteration:
                    break
            bytes_available = len(self._buffer)
            to_read = bytes_available if read_everything else min(remaining_bytes, bytes_available)
            res += self._buffer[:to_read]
            self._buffer = self._buffer[to_read:]
            remaining_bytes -= to_read
        return res

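    # Illustrative chunked copy to a local file (hypothetical destination, not from this commit):
    #   with _StreamingResponse(response, chunk_size=1024 * 1024) as src, open('out.bin', 'wb') as dst:
    #       while chunk := src.read(1024 * 1024):
    #           dst.write(chunk)
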
    def readable(self) -> bool:
        return self._content is not None

    def readline(self, __limit: int = ...) -> bytes:
        raise NotImplementedError()

    def readlines(self, __hint: int = ...) -> List[bytes]:
        raise NotImplementedError()

    def seek(self, __offset: int, __whence: int = ...) -> int:
        raise NotImplementedError()

    def seekable(self) -> bool:
        return False

    def tell(self) -> int:
        raise NotImplementedError()

    def truncate(self, __size: Union[int, None] = ...) -> int:
        raise NotImplementedError()

    def writable(self) -> bool:
        return False

    def write(self, s: Union[bytes, bytearray]) -> int:
        raise NotImplementedError()

    def writelines(self, lines: Iterable[bytes]) -> None:
        raise NotImplementedError()

    def __next__(self) -> bytes:
        # Raise StopIteration at end of stream instead of returning b'' forever.
        chunk = self.read(1)
        if not chunk:
            raise StopIteration
        return chunk

    def __iter__(self) -> Iterator[bytes]:
        self._open()
        return self._content

    def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None],
                 traceback: Union[TracebackType, None]) -> None:
        self._content = None
        self._buffer = b''
        self.close()

0 commit comments