import logging
from datetime import timedelta
from types import TracebackType
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
                    Optional, Tuple, Type, Union)

import requests
import requests.adapters

from . import useragent
from .casing import Casing
from .clock import Clock, RealClock
from .errors import DatabricksError, _ErrorCustomizer, _Parser
from .logger import RoundTrip
from .retries import retried

logger = logging.getLogger('databricks.sdk')


class _BaseClient:

    def __init__(self,
                 debug_truncate_bytes: Optional[int] = None,
                 retry_timeout_seconds: Optional[int] = None,
                 user_agent_base: Optional[str] = None,
                 header_factory: Optional[Callable[[], dict]] = None,
                 max_connection_pools: Optional[int] = None,
                 max_connections_per_pool: Optional[int] = None,
                 pool_block: bool = True,
                 http_timeout_seconds: Optional[float] = None,
                 extra_error_customizers: Optional[List[_ErrorCustomizer]] = None,
                 debug_headers: bool = False,
                 clock: Optional[Clock] = None):
        """
        :param debug_truncate_bytes: Number of bytes of a request or response body to include in debug
            logs before truncating. Defaults to 96.
        :param retry_timeout_seconds: Total time budget, in seconds, for retrying a request. Defaults
            to 300.
        :param user_agent_base: Base User-Agent string to send with every request.
        :param header_factory: A function that returns a dictionary of headers to include in the request.
        :param max_connection_pools: Number of urllib3 connection pools to cache before discarding the least
            recently used pool. The requests library defaults to 10; this client defaults to 20.
        :param max_connections_per_pool: The maximum number of connections to save in each pool. Improves
            performance in multithreaded situations. Defaults to 20.
        :param pool_block: If False, more connections are created when no free ones are available, but they
            are not saved after the first use. If True, blocks when no free connections are available, so
            urllib3 ensures that no more than pool_maxsize connections are used at a time. This prevents
            flooding the platform. The requests library doesn't block by default; this client does.
        :param http_timeout_seconds: Timeout, in seconds, for each individual HTTP request. Defaults to 60.
        :param extra_error_customizers: Additional error customizers to apply when parsing API error
            responses.
        :param debug_headers: Whether to include debug headers in the request log.
        :param clock: Clock object to use for time-related operations.
        """

        self._debug_truncate_bytes = debug_truncate_bytes or 96
        self._debug_headers = debug_headers
        self._retry_timeout_seconds = retry_timeout_seconds or 300
        self._user_agent_base = user_agent_base or useragent.to_string()
        self._header_factory = header_factory
        self._clock = clock or RealClock()
        self._session = requests.Session()
        self._session.auth = self._authenticate

        # We don't use `max_retries` from HTTPAdapter to align with a more production-ready
        # retry strategy established in the Databricks SDK for Go. See _is_retryable and
        # @retried for more details.
        http_adapter = requests.adapters.HTTPAdapter(pool_connections=max_connection_pools or 20,
                                                     pool_maxsize=max_connections_per_pool or 20,
                                                     pool_block=pool_block)
        self._session.mount("https://", http_adapter)

        # Default to 60 seconds.
        self._http_timeout_seconds = http_timeout_seconds or 60

        self._error_parser = _Parser(extra_error_customizers=extra_error_customizers)

    def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
        if self._header_factory:
            headers = self._header_factory()
            for k, v in headers.items():
                r.headers[k] = v
        return r

    @staticmethod
    def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]:
        # Convert True -> "true" for Databricks APIs to understand booleans.
        # See: https://github.com/databricks/databricks-sdk-py/issues/142
        if query is None:
            return None
        with_fixed_bools = {
            k: ('true' if v else 'false') if isinstance(v, bool) else v
            for k, v in query.items()
        }

        # Query parameters may be nested, e.g.
        #   {'filter_by': {'user_ids': [123, 456]}}
        # The HTTP-compatible representation of this is
        #   filter_by.user_ids=123&filter_by.user_ids=456
        # To achieve this, we convert the above dictionary to
        #   {'filter_by.user_ids': [123, 456]}
        # See the following for more information:
        # https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule
        def flatten_dict(d: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
            for k1, v1 in d.items():
                if isinstance(v1, dict):
                    v1 = dict(flatten_dict(v1))
                    for k2, v2 in v1.items():
                        yield f"{k1}.{k2}", v2
                else:
                    yield k1, v1

        flattened = dict(flatten_dict(with_fixed_bools))
        return flattened

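    # For example (hypothetical values), _fix_query_string({'expand': True,
    # 'filter_by': {'user_ids': [123, 456]}}) returns
    # {'expand': 'true', 'filter_by.user_ids': [123, 456]}, which requests then
    # encodes as expand=true&filter_by.user_ids=123&filter_by.user_ids=456.
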
    def do(self,
           method: str,
           url: str,
           query: Optional[dict] = None,
           headers: Optional[dict] = None,
           body: Optional[dict] = None,
           raw: bool = False,
           files=None,
           data=None,
           auth: Optional[Callable[[requests.PreparedRequest], requests.PreparedRequest]] = None,
           response_headers: Optional[List[str]] = None) -> Union[dict, list, BinaryIO]:
        if headers is None:
            headers = {}
        headers['User-Agent'] = self._user_agent_base
        retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
                            is_retryable=self._is_retryable,
                            clock=self._clock)
        response = retryable(self._perform)(method,
                                            url,
                                            query=query,
                                            headers=headers,
                                            body=body,
                                            raw=raw,
                                            files=files,
                                            data=data,
                                            auth=auth)

        resp = dict()
        for header in response_headers if response_headers else []:
            resp[header] = response.headers.get(Casing.to_header_case(header))
        if raw:
            resp["contents"] = _StreamingResponse(response)
            return resp
        if not len(response.content):
            return resp

        json_response = response.json()
        if json_response is None:
            return resp

        if isinstance(json_response, list):
            return json_response

        return {**resp, **json_response}

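    # A minimal usage sketch (hypothetical host and token; the header_factory is
    # whatever credential provider you use):
    #
    #   client = _BaseClient(header_factory=lambda: {'Authorization': f'Bearer {token}'})
    #   me = client.do('GET', f'{host}/api/2.0/preview/scim/v2/Me')
    #
    # A JSON object response comes back as a dict, merged with any requested
    # response_headers; raw=True instead returns {'contents': _StreamingResponse(...)}.
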
    @staticmethod
    def _is_retryable(err: BaseException) -> Optional[str]:
        # This method is a Databricks-specific port of the urllib3 retry logic
        # (see https://github.com/urllib3/urllib3/blob/main/src/urllib3/util/retry.py)
        # and the Databricks SDK for Go retry logic
        # (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go)
        from urllib3.exceptions import ProxyError
        if isinstance(err, ProxyError):
            err = err.original_error
        if isinstance(err, requests.ConnectionError):
            # Corresponds to the `connection reset by peer` and `connection refused` errors
            # from Go, which are generally caused by temporary glitches in the networking
            # stack. They are also caused by endpoint protection software, like ZScaler,
            # dropping connections that are not yet authenticated.
            #
            # Return a simple string for debug log readability, as `raise TimeoutError(...) from err`
            # will bubble up the original exception in case we reach max retries.
            return 'cannot connect'
        if isinstance(err, requests.Timeout):
            # Corresponds to the `TLS handshake timeout` and `i/o timeout` errors in Go.
            #
            # Return a simple string for debug log readability, as `raise TimeoutError(...) from err`
            # will bubble up the original exception in case we reach max retries.
            return 'timeout'
        if isinstance(err, DatabricksError):
            message = str(err)
            transient_error_string_matches = [
                "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException",
                "does not have any associated worker environments",
                "There is no worker environment with id",
                "Unknown worker environment", "ClusterNotReadyException", "Unexpected error",
                "Please try again later or try a faster operation.",
                "RPC token bucket limit has been exceeded",
            ]
            for substring in transient_error_string_matches:
                if substring not in message:
                    continue
                return f'matched {substring}'
        return None

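    # For example, _is_retryable(requests.Timeout()) returns 'timeout', so the
    # request is retried until the overall retry timeout elapses, while an error
    # matching none of the conditions above returns None and is raised immediately.
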
    def _perform(self,
                 method: str,
                 url: str,
                 query: Optional[dict] = None,
                 headers: Optional[dict] = None,
                 body: Optional[dict] = None,
                 raw: bool = False,
                 files=None,
                 data=None,
                 auth: Optional[Callable[[requests.PreparedRequest], requests.PreparedRequest]] = None):
        response = self._session.request(method,
                                         url,
                                         params=self._fix_query_string(query),
                                         json=body,
                                         headers=headers,
                                         files=files,
                                         data=data,
                                         auth=auth,
                                         stream=raw,
                                         timeout=self._http_timeout_seconds)
        self._record_request_log(response, raw=raw or data is not None or files is not None)
        error = self._error_parser.get_api_error(response)
        if error is not None:
            raise error from None
        return response

    def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate())


class _StreamingResponse(BinaryIO):
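    """A read-only, non-seekable file-like wrapper around a streamed requests.Response.

    Bytes are pulled lazily from response.iter_content and buffered locally, so the
    entire body is never loaded into memory at once.
    """
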
    _response: requests.Response
    _buffer: bytes
    _content: Union[Iterator[bytes], None]
    _chunk_size: Union[int, None]
    _closed: bool = False

    def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None):
        self._response = response
        self._buffer = b''
        self._content = None
        self._chunk_size = chunk_size

    def fileno(self) -> int:
        raise NotImplementedError()

    def flush(self) -> None:
        pass

    def _open(self) -> None:
        if self._closed:
            raise ValueError("I/O operation on closed file")
        if not self._content:
            self._content = self._response.iter_content(chunk_size=self._chunk_size)

    def __enter__(self) -> BinaryIO:
        self._open()
        return self

    def set_chunk_size(self, chunk_size: Union[int, None]) -> None:
        self._chunk_size = chunk_size

    def close(self) -> None:
        self._response.close()
        self._closed = True

    def isatty(self) -> bool:
        return False

    def read(self, n: int = -1) -> bytes:
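        # Serve reads from the internal buffer, refilling it one chunk at a time
        # from the response iterator. A negative n reads until the stream is
        # exhausted, mirroring the usual file-object convention.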
        self._open()
        read_everything = n < 0
        remaining_bytes = n
        res = b''
        while remaining_bytes > 0 or read_everything:
            if len(self._buffer) == 0:
                try:
                    self._buffer = next(self._content)
                except StopIteration:
                    break
            bytes_available = len(self._buffer)
            to_read = bytes_available if read_everything else min(remaining_bytes, bytes_available)
            res += self._buffer[:to_read]
            self._buffer = self._buffer[to_read:]
            remaining_bytes -= to_read
        return res

    def readable(self) -> bool:
        return self._content is not None

    def readline(self, __limit: int = ...) -> bytes:
        raise NotImplementedError()

    def readlines(self, __hint: int = ...) -> List[bytes]:
        raise NotImplementedError()

    def seek(self, __offset: int, __whence: int = ...) -> int:
        raise NotImplementedError()

    def seekable(self) -> bool:
        return False

    def tell(self) -> int:
        raise NotImplementedError()

    def truncate(self, __size: Union[int, None] = ...) -> int:
        raise NotImplementedError()

    def writable(self) -> bool:
        return False

    def write(self, s: Union[bytes, bytearray]) -> int:
        raise NotImplementedError()

    def writelines(self, lines: Iterable[bytes]) -> None:
        raise NotImplementedError()

    def __next__(self) -> bytes:
        # read(1) returns b'' once the stream is exhausted, which must be
        # translated into StopIteration to terminate iteration.
        chunk = self.read(1)
        if not chunk:
            raise StopIteration
        return chunk

    def __iter__(self) -> Iterator[bytes]:
        self._open()
        return self._content

    def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None],
                 traceback: Union[TracebackType, None]) -> None:
        self._content = None
        self._buffer = b''
        self.close()
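

# A minimal usage sketch for streaming a download (hypothetical host, path and
# `process` function; the exact endpoint depends on the API being called):
#
#   client = _BaseClient()
#   resp = client.do('GET', f'{host}/api/2.0/fs/files{path}', raw=True)
#   with resp['contents'] as stream:
#       stream.set_chunk_size(1024 * 1024)
#       for chunk in iter(lambda: stream.read(1024 * 1024), b''):
#           process(chunk)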