diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py
new file mode 100644
index 000000000..62c2974ec
--- /dev/null
+++ b/databricks/sdk/_base_client.py
@@ -0,0 +1,323 @@
+import logging
+from datetime import timedelta
+from types import TracebackType
+from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
+                    Optional, Type, Union)
+
+import requests
+import requests.adapters
+
+from . import useragent
+from .casing import Casing
+from .clock import Clock, RealClock
+from .errors import DatabricksError, _ErrorCustomizer, _Parser
+from .logger import RoundTrip
+from .retries import retried
+
+logger = logging.getLogger('databricks.sdk')
+
+
+class _BaseClient:
+
+    def __init__(self,
+                 debug_truncate_bytes: int = None,
+                 retry_timeout_seconds: int = None,
+                 user_agent_base: str = None,
+                 header_factory: Callable[[], dict] = None,
+                 max_connection_pools: int = None,
+                 max_connections_per_pool: int = None,
+                 pool_block: bool = True,
+                 http_timeout_seconds: float = None,
+                 extra_error_customizers: List[_ErrorCustomizer] = None,
+                 debug_headers: bool = False,
+                 clock: Clock = None):
+        """
+        :param debug_truncate_bytes: Number of bytes to truncate debug logs to. Defaults to 96.
+        :param retry_timeout_seconds: Total time budget for retrying failed requests. Defaults to 300 seconds.
+        :param user_agent_base: Base User-Agent string. Defaults to the SDK-wide user agent.
+        :param header_factory: A function that returns a dictionary of headers to include in the request.
+        :param max_connection_pools: Number of urllib3 connection pools to cache before discarding the least
+            recently used pool. The requests library's default value is 10.
+        :param max_connections_per_pool: The maximum number of connections to save in the pool. Improves performance
+            in multithreaded situations. For now, we're setting it to the same value as connection_pool_size.
+        :param pool_block: If pool_block is False, more connections are created, but not saved after the
+            first use. If True, blocks when no free connections are available; urllib3 ensures that no more
+            than pool_maxsize connections are used at a time, which prevents flooding the platform. By
+            default, the requests library doesn't block.
+        :param http_timeout_seconds: Timeout for a single HTTP request, in seconds. Defaults to 60.
+        :param extra_error_customizers: Additional customizers applied when parsing API error responses.
+        :param debug_headers: Whether to include debug headers in the request log.
+        :param clock: Clock object to use for time-related operations.
+        """
+
+        self._debug_truncate_bytes = debug_truncate_bytes or 96
+        self._debug_headers = debug_headers
+        self._retry_timeout_seconds = retry_timeout_seconds or 300
+        self._user_agent_base = user_agent_base or useragent.to_string()
+        self._header_factory = header_factory
+        self._clock = clock or RealClock()
+        self._session = requests.Session()
+        self._session.auth = self._authenticate
+
+        # We don't use `max_retries` from HTTPAdapter to align with a more production-ready
+        # retry strategy established in the Databricks SDK for Go. See _is_retryable and
+        # @retried for more details.
+        http_adapter = requests.adapters.HTTPAdapter(pool_connections=max_connection_pools or 20,
+                                                     pool_maxsize=max_connections_per_pool or 20,
+                                                     pool_block=pool_block)
+        self._session.mount("https://", http_adapter)
+
+        # Default to 60 seconds
+        self._http_timeout_seconds = http_timeout_seconds or 60
+
+        self._error_parser = _Parser(extra_error_customizers=extra_error_customizers)
+
+    def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
+        if self._header_factory:
+            headers = self._header_factory()
+            for k, v in headers.items():
+                r.headers[k] = v
+        return r
+
+    @staticmethod
+    def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]:
+        # Convert True -> "true" for Databricks APIs to understand booleans.
+        # See: https://github.com/databricks/databricks-sdk-py/issues/142
+        if query is None:
+            return None
+        with_fixed_bools = {k: v if type(v) != bool else ('true' if v else 'false') for k, v in query.items()}
+
+        # Query parameters may be nested, e.g.
+        #   {'filter_by': {'user_ids': [123, 456]}}
+        # The HTTP-compatible representation of this is
+        #   filter_by.user_ids=123&filter_by.user_ids=456
+        # To achieve this, we convert the above dictionary to
+        #   {'filter_by.user_ids': [123, 456]}
+        # See the following for more information:
+        # https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule
+        def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+            for k1, v1 in d.items():
+                if isinstance(v1, dict):
+                    v1 = dict(flatten_dict(v1))
+                    for k2, v2 in v1.items():
+                        yield f"{k1}.{k2}", v2
+                else:
+                    yield k1, v1
+
+        flattened = dict(flatten_dict(with_fixed_bools))
+        return flattened
+
+    def do(self,
+           method: str,
+           url: str,
+           query: dict = None,
+           headers: dict = None,
+           body: dict = None,
+           raw: bool = False,
+           files=None,
+           data=None,
+           auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None,
+           response_headers: List[str] = None) -> Union[dict, list, BinaryIO]:
+        if headers is None:
+            headers = {}
+        headers['User-Agent'] = self._user_agent_base
+        retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
+                            is_retryable=self._is_retryable,
+                            clock=self._clock)
+        response = retryable(self._perform)(method,
+                                            url,
+                                            query=query,
+                                            headers=headers,
+                                            body=body,
+                                            raw=raw,
+                                            files=files,
+                                            data=data,
+                                            auth=auth)
+
+        resp = dict()
+        for header in response_headers if response_headers else []:
+            resp[header] = response.headers.get(Casing.to_header_case(header))
+        if raw:
+            resp["contents"] = _StreamingResponse(response)
+            return resp
+        if not len(response.content):
+            return resp
+
+        json_response = response.json()
+        if json_response is None:
+            return resp
+
+        if isinstance(json_response, list):
+            return json_response
+
+        return {**resp, **json_response}
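The boolean and nesting rules in `_fix_query_string` are easiest to see on a concrete input; a small illustration (the parameter names are made up, not from this PR):

```python
from databricks.sdk._base_client import _BaseClient

# A query mixing a boolean with a nested filter:
query = {'expand_tasks': True, 'filter_by': {'user_ids': [123, 456]}}

# Booleans become lowercase strings, nested dicts are dot-flattened, and the
# list survives so requests serializes it as repeated parameters:
#   expand_tasks=true&filter_by.user_ids=123&filter_by.user_ids=456
assert _BaseClient._fix_query_string(query) == {
    'expand_tasks': 'true',
    'filter_by.user_ids': [123, 456],
}
```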
+    @staticmethod
+    def _is_retryable(err: BaseException) -> Optional[str]:
+        # this method is a Databricks-specific port of urllib3 retries
+        # (see https://github.com/urllib3/urllib3/blob/main/src/urllib3/util/retry.py)
+        # and Databricks SDK for Go retries
+        # (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go)
+        from urllib3.exceptions import ProxyError
+        if isinstance(err, ProxyError):
+            err = err.original_error
+        if isinstance(err, requests.ConnectionError):
+            # corresponds to the `connection reset by peer` and `connection refused` errors from Go,
+            # which are generally caused by temporary glitches in the networking stack; endpoint
+            # protection software, like ZScaler, can also drop connections that are not yet
+            # authenticated.
+            #
+            # return a simple string for debug log readability, as `raise TimeoutError(...) from err`
+            # will bubble up the original exception in case we reach max retries.
+            return 'cannot connect'
+        if isinstance(err, requests.Timeout):
+            # corresponds to `TLS handshake timeout` and `i/o timeout` in Go.
+            #
+            # return a simple string for debug log readability, as `raise TimeoutError(...) from err`
+            # will bubble up the original exception in case we reach max retries.
+            return 'timeout'
+        if isinstance(err, DatabricksError):
+            message = str(err)
+            transient_error_string_matches = [
+                "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException",
+                "does not have any associated worker environments", "There is no worker environment with id",
+                "Unknown worker environment", "ClusterNotReadyException", "Unexpected error",
+                "Please try again later or try a faster operation.",
+                "RPC token bucket limit has been exceeded",
+            ]
+            for substring in transient_error_string_matches:
+                if substring not in message:
+                    continue
+                return f'matched {substring}'
+        return None
+
+    def _perform(self,
+                 method: str,
+                 url: str,
+                 query: dict = None,
+                 headers: dict = None,
+                 body: dict = None,
+                 raw: bool = False,
+                 files=None,
+                 data=None,
+                 auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
+        response = self._session.request(method,
+                                         url,
+                                         params=self._fix_query_string(query),
+                                         json=body,
+                                         headers=headers,
+                                         files=files,
+                                         data=data,
+                                         auth=auth,
+                                         stream=raw,
+                                         timeout=self._http_timeout_seconds)
+        self._record_request_log(response, raw=raw or data is not None or files is not None)
+        error = self._error_parser.get_api_error(response)
+        if error is not None:
+            raise error from None
+        return response
+
+    def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:
+        if not logger.isEnabledFor(logging.DEBUG):
+            return
+        logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate())
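The retry predicate returns a human-readable reason string when an error is retryable and `None` otherwise; the string ends up in debug logs. A quick sketch of that contract (hand-constructed exceptions; `DatabricksError` messages are matched against the transient substrings listed above):

```python
import requests

from databricks.sdk._base_client import _BaseClient

# Connection-level failures are always retryable; the string is the logged reason.
assert _BaseClient._is_retryable(requests.ConnectionError()) == 'cannot connect'
assert _BaseClient._is_retryable(requests.Timeout()) == 'timeout'

# Anything else falls through and is not retried.
assert _BaseClient._is_retryable(ValueError('boom')) is None
```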
+
+
+class _StreamingResponse(BinaryIO):
+    _response: requests.Response
+    _buffer: bytes
+    _content: Union[Iterator[bytes], None]
+    _chunk_size: Union[int, None]
+    _closed: bool = False
+
+    def fileno(self) -> int:
+        pass
+
+    def flush(self) -> int:
+        pass
+
+    def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None):
+        self._response = response
+        self._buffer = b''
+        self._content = None
+        self._chunk_size = chunk_size
+
+    def _open(self) -> None:
+        if self._closed:
+            raise ValueError("I/O operation on closed file")
+        if not self._content:
+            self._content = self._response.iter_content(chunk_size=self._chunk_size)
+
+    def __enter__(self) -> BinaryIO:
+        self._open()
+        return self
+
+    def set_chunk_size(self, chunk_size: Union[int, None]) -> None:
+        self._chunk_size = chunk_size
+
+    def close(self) -> None:
+        self._response.close()
+        self._closed = True
+
+    def isatty(self) -> bool:
+        return False
+
+    def read(self, n: int = -1) -> bytes:
+        # Read up to n bytes, buffering the unread remainder of the last
+        # fetched chunk; n < 0 drains the entire stream.
+        self._open()
+        read_everything = n < 0
+        remaining_bytes = n
+        res = b''
+        while remaining_bytes > 0 or read_everything:
+            if len(self._buffer) == 0:
+                try:
+                    self._buffer = next(self._content)
+                except StopIteration:
+                    break
+            bytes_available = len(self._buffer)
+            to_read = bytes_available if read_everything else min(remaining_bytes, bytes_available)
+            res += self._buffer[:to_read]
+            self._buffer = self._buffer[to_read:]
+            remaining_bytes -= to_read
+        return res
+
+    def readable(self) -> bool:
+        return self._content is not None
+
+    def readline(self, __limit: int = ...) -> bytes:
+        raise NotImplementedError()
+
+    def readlines(self, __hint: int = ...) -> List[bytes]:
+        raise NotImplementedError()
+
+    def seek(self, __offset: int, __whence: int = ...) -> int:
+        raise NotImplementedError()
+
+    def seekable(self) -> bool:
+        return False
+
+    def tell(self) -> int:
+        raise NotImplementedError()
+
+    def truncate(self, __size: Union[int, None] = ...) -> int:
+        raise NotImplementedError()
+
+    def writable(self) -> bool:
+        return False
+
+    def write(self, s: Union[bytes, bytearray]) -> int:
+        raise NotImplementedError()
+
+    def writelines(self, lines: Iterable[bytes]) -> None:
+        raise NotImplementedError()
+
+    def __next__(self) -> bytes:
+        return self.read(1)
+
+    def __iter__(self) -> Iterator[bytes]:
+        return self._content
+
+    def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None],
+                 traceback: Union[TracebackType, None]) -> None:
+        self._content = None
+        self._buffer = b''
+        self.close()
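Before the `core.py` changes below, a sketch of how the extracted client is exercised end to end. The host, token header, and paths are placeholders for illustration, not values from this PR; real callers pass `Config.authenticate` as the `header_factory`, as `ApiClient` does below:

```python
from databricks.sdk._base_client import _BaseClient

# Hypothetical bearer-token header factory.
client = _BaseClient(header_factory=lambda: {'Authorization': 'Bearer <token>'},
                     max_connection_pools=20,
                     max_connections_per_pool=20)

# JSON calls return the parsed body, merged with any requested response headers.
cluster = client.do('GET', 'https://example.cloud.databricks.com/api/2.1/clusters/get',
                    query={'cluster_id': 'abc-123'})

# raw=True returns {'contents': _StreamingResponse(...)}, which buffers
# iter_content() chunks behind a file-like read().
res = client.do('GET', 'https://example.cloud.databricks.com/api/2.0/fs/files/report.bin',
                raw=True)
with res['contents'] as stream:
    first_kb = stream.read(1024)
```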
diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py
index 77e8c9aac..eab22cd71 100644
--- a/databricks/sdk/core.py
+++ b/databricks/sdk/core.py
@@ -1,19 +1,13 @@
 import re
-from datetime import timedelta
-from types import TracebackType
-from typing import Any, BinaryIO, Iterator, Type
+from typing import BinaryIO
 from urllib.parse import urlencode
 
-from requests.adapters import HTTPAdapter
-
-from .casing import Casing
+from ._base_client import _BaseClient
 from .config import *
 # To preserve backwards compatibility (as these definitions were previously in this module)
 from .credentials_provider import *
-from .errors import DatabricksError, _ErrorCustomizer, _Parser
-from .logger import RoundTrip
+from .errors import DatabricksError, _ErrorCustomizer
 from .oauth import retrieve_token
-from .retries import retried
 
 __all__ = ['Config', 'DatabricksError']
 
@@ -25,53 +19,19 @@ class ApiClient:
 
-    _cfg: Config
-    _RETRY_AFTER_DEFAULT: int = 1
-
-    def __init__(self, cfg: Config = None):
-
-        if cfg is None:
-            cfg = Config()
+    def __init__(self, cfg: Config):
 
         self._cfg = cfg
-        # See https://github.com/databricks/databricks-sdk-go/blob/main/client/client.go#L34-L35
-        self._debug_truncate_bytes = cfg.debug_truncate_bytes if cfg.debug_truncate_bytes else 96
-        self._retry_timeout_seconds = cfg.retry_timeout_seconds if cfg.retry_timeout_seconds else 300
-        self._user_agent_base = cfg.user_agent
-        self._session = requests.Session()
-        self._session.auth = self._authenticate
-
-        # Number of urllib3 connection pools to cache before discarding the least
-        # recently used pool. Python requests default value is 10.
-        pool_connections = cfg.max_connection_pools
-        if pool_connections is None:
-            pool_connections = 20
-
-        # The maximum number of connections to save in the pool. Improves performance
-        # in multithreaded situations. For now, we're setting it to the same value
-        # as connection_pool_size.
-        pool_maxsize = cfg.max_connections_per_pool
-        if cfg.max_connections_per_pool is None:
-            pool_maxsize = pool_connections
-
-        # If pool_block is False, then more connections will are created,
-        # but not saved after the first use. Blocks when no free connections are available.
-        # urllib3 ensures that no more than pool_maxsize connections are used at a time.
-        # Prevents platform from flooding. By default, requests library doesn't block.
-        pool_block = True
-
-        # We don't use `max_retries` from HTTPAdapter to align with a more production-ready
-        # retry strategy established in the Databricks SDK for Go. See _is_retryable and
-        # @retried for more details.
-        http_adapter = HTTPAdapter(pool_connections=pool_connections,
-                                   pool_maxsize=pool_maxsize,
-                                   pool_block=pool_block)
-        self._session.mount("https://", http_adapter)
-
-        # Default to 60 seconds
-        self._http_timeout_seconds = cfg.http_timeout_seconds if cfg.http_timeout_seconds else 60
-
-        self._error_parser = _Parser(extra_error_customizers=[_AddDebugErrorCustomizer(cfg)])
+        self._api_client = _BaseClient(debug_truncate_bytes=cfg.debug_truncate_bytes,
+                                       retry_timeout_seconds=cfg.retry_timeout_seconds,
+                                       user_agent_base=cfg.user_agent,
+                                       header_factory=cfg.authenticate,
+                                       max_connection_pools=cfg.max_connection_pools,
+                                       max_connections_per_pool=cfg.max_connections_per_pool,
+                                       pool_block=True,
+                                       http_timeout_seconds=cfg.http_timeout_seconds,
+                                       extra_error_customizers=[_AddDebugErrorCustomizer(cfg)],
+                                       clock=cfg.clock)
 
     @property
     def account_id(self) -> str:
@@ -81,40 +41,6 @@ def account_id(self) -> str:
     def is_account_client(self) -> bool:
         return self._cfg.is_account_client
 
-    def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
-        headers = self._cfg.authenticate()
-        for k, v in headers.items():
-            r.headers[k] = v
-        return r
-
-    @staticmethod
-    def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]:
-        # Convert True -> "true" for Databricks APIs to understand booleans.
-        # See: https://github.com/databricks/databricks-sdk-py/issues/142
-        if query is None:
-            return None
-        with_fixed_bools = {k: v if type(v) != bool else ('true' if v else 'false') for k, v in query.items()}
-
-        # Query parameters may be nested, e.g.
-        #   {'filter_by': {'user_ids': [123, 456]}}
-        # The HTTP-compatible representation of this is
-        #   filter_by.user_ids=123&filter_by.user_ids=456
-        # To achieve this, we convert the above dictionary to
-        #   {'filter_by.user_ids': [123, 456]}
-        # See the following for more information:
-        # https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule
-        def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
-            for k1, v1 in d.items():
-                if isinstance(v1, dict):
-                    v1 = dict(flatten_dict(v1))
-                    for k2, v2 in v1.items():
-                        yield f"{k1}.{k2}", v2
-                else:
-                    yield k1, v1
-
-        flattened = dict(flatten_dict(with_fixed_bools))
-        return flattened
-
     def get_oauth_token(self, auth_details: str) -> Token:
         if not self._cfg.auth_type:
             self._cfg.authenticate()
@@ -142,115 +68,22 @@ def do(self,
            files=None,
            data=None,
            auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None,
-           response_headers: List[str] = None) -> Union[dict, BinaryIO]:
-        if headers is None:
-            headers = {}
+           response_headers: List[str] = None) -> Union[dict, list, BinaryIO]:
         if url is None:
             # Remove extra `/` from path for Files API
             # Once we've fixed the OpenAPI spec, we can remove this
             path = re.sub('^/api/2.0/fs/files//', '/api/2.0/fs/files/', path)
             url = f"{self._cfg.host}{path}"
-        headers['User-Agent'] = self._user_agent_base
-        retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
-                            is_retryable=self._is_retryable,
-                            clock=self._cfg.clock)
-        response = retryable(self._perform)(method,
-                                            url,
-                                            query=query,
-                                            headers=headers,
-                                            body=body,
-                                            raw=raw,
-                                            files=files,
-                                            data=data,
-                                            auth=auth)
-
-        resp = dict()
-        for header in response_headers if response_headers else []:
-            resp[header] = response.headers.get(Casing.to_header_case(header))
-        if raw:
-            resp["contents"] = StreamingResponse(response)
-            return resp
-        if not len(response.content):
-            return resp
-
-        jsonResponse = response.json()
-        if jsonResponse is None:
-            return resp
-
-        if isinstance(jsonResponse, list):
-            return jsonResponse
-
-        return {**resp, **jsonResponse}
-
-    @staticmethod
-    def _is_retryable(err: BaseException) -> Optional[str]:
-        # this method is Databricks-specific port of urllib3 retries
-        # (see https://github.com/urllib3/urllib3/blob/main/src/urllib3/util/retry.py)
-        # and Databricks SDK for Go retries
-        # (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go)
-        from urllib3.exceptions import ProxyError
-        if isinstance(err, ProxyError):
-            err = err.original_error
-        if isinstance(err, requests.ConnectionError):
-            # corresponds to `connection reset by peer` and `connection refused` errors from Go,
-            # which are generally related to the temporary glitches in the networking stack,
-            # also caused by endpoint protection software, like ZScaler, to drop connections while
-            # not yet authenticated.
-            #
-            # return a simple string for debug log readability, as `raise TimeoutError(...) from err`
-            # will bubble up the original exception in case we reach max retries.
-            return f'cannot connect'
-        if isinstance(err, requests.Timeout):
-            # corresponds to `TLS handshake timeout` and `i/o timeout` in Go.
-            #
-            # return a simple string for debug log readability, as `raise TimeoutError(...) from err`
-            # will bubble up the original exception in case we reach max retries.
-            return f'timeout'
-        if isinstance(err, DatabricksError):
-            message = str(err)
-            transient_error_string_matches = [
-                "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException",
-                "does not have any associated worker environments", "There is no worker environment with id",
-                "Unknown worker environment", "ClusterNotReadyException", "Unexpected error",
-                "Please try again later or try a faster operation.",
-                "RPC token bucket limit has been exceeded",
-            ]
-            for substring in transient_error_string_matches:
-                if substring not in message:
-                    continue
-                return f'matched {substring}'
-        return None
-
-    def _perform(self,
-                 method: str,
-                 url: str,
-                 query: dict = None,
-                 headers: dict = None,
-                 body: dict = None,
-                 raw: bool = False,
-                 files=None,
-                 data=None,
-                 auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
-        response = self._session.request(method,
-                                         url,
-                                         params=self._fix_query_string(query),
-                                         json=body,
-                                         headers=headers,
-                                         files=files,
-                                         data=data,
-                                         auth=auth,
-                                         stream=raw,
-                                         timeout=self._http_timeout_seconds)
-        self._record_request_log(response, raw=raw or data is not None or files is not None)
-        error = self._error_parser.get_api_error(response)
-        if error is not None:
-            raise error from None
-        return response
-
-    def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:
-        if not logger.isEnabledFor(logging.DEBUG):
-            return
-        logger.debug(RoundTrip(response, self._cfg.debug_headers, self._debug_truncate_bytes, raw).generate())
+        return self._api_client.do(method=method,
+                                   url=url,
+                                   query=query,
+                                   headers=headers,
+                                   body=body,
+                                   raw=raw,
+                                   files=files,
+                                   data=data,
+                                   auth=auth,
+                                   response_headers=response_headers)
 
 
 class _AddDebugErrorCustomizer(_ErrorCustomizer):
@@ -264,103 +97,3 @@ def customize_error(self, response: requests.Response, kwargs: dict):
         if response.status_code in (401, 403):
             message = kwargs.get('message', 'request failed')
             kwargs['message'] = self._cfg.wrap_debug_info(message)
-
-
-class StreamingResponse(BinaryIO):
-    _response: requests.Response
-    _buffer: bytes
-    _content: Union[Iterator[bytes], None]
-    _chunk_size: Union[int, None]
-    _closed: bool = False
-
-    def fileno(self) -> int:
-        pass
-
-    def flush(self) -> int:
-        pass
-
-    def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None):
-        self._response = response
-        self._buffer = b''
-        self._content = None
-        self._chunk_size = chunk_size
-
-    def _open(self) -> None:
-        if self._closed:
-            raise ValueError("I/O operation on closed file")
-        if not self._content:
-            self._content = self._response.iter_content(chunk_size=self._chunk_size)
-
-    def __enter__(self) -> BinaryIO:
-        self._open()
-        return self
-
-    def set_chunk_size(self, chunk_size: Union[int, None]) -> None:
-        self._chunk_size = chunk_size
-
-    def close(self) -> None:
-        self._response.close()
-        self._closed = True
-
-    def isatty(self) -> bool:
-        return False
-
-    def read(self, n: int = -1) -> bytes:
-        self._open()
-        read_everything = n < 0
-        remaining_bytes = n
-        res = b''
-        while remaining_bytes > 0 or read_everything:
-            if len(self._buffer) == 0:
-                try:
-                    self._buffer = next(self._content)
-                except StopIteration:
-                    break
-            bytes_available = len(self._buffer)
-            to_read = bytes_available if read_everything else min(remaining_bytes, bytes_available)
-            res += self._buffer[:to_read]
-            self._buffer = self._buffer[to_read:]
-            remaining_bytes -= to_read
-        return res
-
-    def readable(self) -> bool:
-        return self._content is not None
-
-    def readline(self, __limit: int = ...) -> bytes:
-        raise NotImplementedError()
-
-    def readlines(self, __hint: int = ...) -> List[bytes]:
-        raise NotImplementedError()
-
-    def seek(self, __offset: int, __whence: int = ...) -> int:
-        raise NotImplementedError()
-
-    def seekable(self) -> bool:
-        return False
-
-    def tell(self) -> int:
-        raise NotImplementedError()
-
-    def truncate(self, __size: Union[int, None] = ...) -> int:
-        raise NotImplementedError()
-
-    def writable(self) -> bool:
-        return False
-
-    def write(self, s: Union[bytes, bytearray]) -> int:
-        raise NotImplementedError()
-
-    def writelines(self, lines: Iterable[bytes]) -> None:
-        raise NotImplementedError()
-
-    def __next__(self) -> bytes:
-        return self.read(1)
-
-    def __iter__(self) -> Iterator[bytes]:
-        return self._content
-
-    def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None],
-                 traceback: Union[TracebackType, None]) -> None:
-        self._content = None
-        self._buffer = b''
-        self.close()
diff --git a/tests/fixture_server.py b/tests/fixture_server.py
new file mode 100644
index 000000000..e15f9cf2d
--- /dev/null
+++ b/tests/fixture_server.py
@@ -0,0 +1,31 @@
+import contextlib
+import functools
+import typing
+from http.server import BaseHTTPRequestHandler
+
+
+@contextlib.contextmanager
+def http_fixture_server(handler: typing.Callable[[BaseHTTPRequestHandler], None]):
+    from http.server import HTTPServer
+    from threading import Thread
+
+    class _handler(BaseHTTPRequestHandler):
+
+        def __init__(self, handler: typing.Callable[[BaseHTTPRequestHandler], None], *args):
+            self._handler = handler
+            super().__init__(*args)
+
+        def __getattr__(self, item):
+            if 'do_' != item[0:3]:
+                raise AttributeError(f'method {item} not found')
+            return functools.partial(self._handler, self)
+
+    handler_factory = functools.partial(_handler, handler)
+    srv = HTTPServer(('localhost', 0), handler_factory)
+    t = Thread(target=srv.serve_forever)
+    try:
+        t.daemon = True
+        t.start()
+        yield 'http://{0}:{1}'.format(*srv.server_address)
+    finally:
+        srv.shutdown()
diff --git a/tests/test_base_client.py b/tests/test_base_client.py
new file mode 100644
index 000000000..e9e7324a9
--- /dev/null
+++ b/tests/test_base_client.py
@@ -0,0 +1,278 @@
+from http.server import BaseHTTPRequestHandler
+from typing import Iterator, List
+
+import pytest
+import requests
+
+from databricks.sdk import errors, useragent
+from databricks.sdk._base_client import _BaseClient, _StreamingResponse
+from databricks.sdk.core import DatabricksError
+
+from .clock import FakeClock
+from .fixture_server import http_fixture_server
+
+
+class DummyResponse(requests.Response):
+    _content: Iterator[bytes]
+    _closed: bool = False
+
+    def __init__(self, content: List[bytes]) -> None:
+        super().__init__()
+        self._content = iter(content)
+
+    def iter_content(self, chunk_size: int = 1, decode_unicode=False) -> Iterator[bytes]:
+        return self._content
+
+    def close(self):
+        self._closed = True
+
+    def isClosed(self):
+        return self._closed
+
+
+def test_streaming_response_read(config):
+    content = b"some initial binary data: \x00\x01"
+    response = _StreamingResponse(DummyResponse([content]))
+    assert response.read() == content
+
+
+def test_streaming_response_read_partial(config):
+    content = b"some initial binary data: \x00\x01"
+    response = _StreamingResponse(DummyResponse([content]))
+    assert response.read(8) == b"some ini"
+
+
+def test_streaming_response_read_full(config):
+    content = b"some initial binary data: \x00\x01"
+    response = _StreamingResponse(DummyResponse([content, content]))
+    assert response.read() == content + content
+
+
+def test_streaming_response_read_closes(config):
+    content = b"some initial binary data: \x00\x01"
+    dummy_response = DummyResponse([content])
+    with _StreamingResponse(dummy_response) as response:
+        assert response.read() == content
+    assert dummy_response.isClosed()
+
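The `http_fixture_server` helper above hands every `do_GET`/`do_POST`/... call to the supplied callable via the `__getattr__` dispatch, which is what keeps the retry tests below compact. A minimal standalone sketch (assuming the `tests` package is importable; the handler body is made up):

```python
import urllib.request
from http.server import BaseHTTPRequestHandler

from tests.fixture_server import http_fixture_server


def handler(h: BaseHTTPRequestHandler):
    # Invoked for any HTTP verb via the __getattr__ dispatch above.
    h.send_response(200)
    h.send_header('Content-Type', 'application/json')
    h.end_headers()
    h.wfile.write(b'{"ok": true}')


with http_fixture_server(handler) as host:
    # host looks like 'http://localhost:<random port>'.
    body = urllib.request.urlopen(f'{host}/anything').read()
    assert body == b'{"ok": true}'
```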
+@pytest.mark.parametrize('status_code,headers,body,expected_error', [
+    (400, {}, {
+        "message":
+        "errorMessage",
+        "details": [{
+            "type": DatabricksError._error_info_type,
+            "reason": "error reason",
+            "domain": "error domain",
+            "metadata": {
+                "etag": "error etag"
+            },
+        }, {
+            "type": "wrong type",
+            "reason": "wrong reason",
+            "domain": "wrong domain",
+            "metadata": {
+                "etag": "wrong etag"
+            }
+        }],
+    },
+     errors.BadRequest('errorMessage',
+                       details=[{
+                           'type': DatabricksError._error_info_type,
+                           'reason': 'error reason',
+                           'domain': 'error domain',
+                           'metadata': {
+                               'etag': 'error etag'
+                           },
+                       }])),
+    (401, {}, {
+        'error_code': 'UNAUTHORIZED',
+        'message': 'errorMessage',
+    }, errors.Unauthenticated('errorMessage', error_code='UNAUTHORIZED')),
+    (403, {}, {
+        'error_code': 'FORBIDDEN',
+        'message': 'errorMessage',
+    }, errors.PermissionDenied('errorMessage', error_code='FORBIDDEN')),
+    (429, {}, {
+        'error_code': 'TOO_MANY_REQUESTS',
+        'message': 'errorMessage',
+    }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=1)),
+    (429, {
+        'Retry-After': '100'
+    }, {
+        'error_code': 'TOO_MANY_REQUESTS',
+        'message': 'errorMessage',
+    }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)),
+    (503, {}, {
+        'error_code': 'TEMPORARILY_UNAVAILABLE',
+        'message': 'errorMessage',
+    }, errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE',
+                                     retry_after_secs=1)),
+    (503, {
+        'Retry-After': '100'
+    }, {
+        'error_code': 'TEMPORARILY_UNAVAILABLE',
+        'message': 'errorMessage',
+    },
+     errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE',
+                                   retry_after_secs=100)),
+    (404, {}, {
+        'scimType': 'scim type',
+        'detail': 'detail',
+        'status': 'status',
+    }, errors.NotFound('scim type detail', error_code='SCIM_status')),
+])
+def test_error(requests_mock, status_code, headers, body, expected_error):
+    client = _BaseClient(clock=FakeClock())
+    requests_mock.get("/test", json=body, status_code=status_code, headers=headers)
+    with pytest.raises(DatabricksError) as raised:
+        client._perform("GET", "https://localhost/test", headers={"test": "test"})
+    actual = raised.value
+    assert isinstance(actual, type(expected_error))
+    assert str(actual) == str(expected_error)
+    assert actual.error_code == expected_error.error_code
+    assert actual.retry_after_secs == expected_error.retry_after_secs
+    expected_error_infos, actual_error_infos = expected_error.get_error_info(), actual.get_error_info()
+    assert len(expected_error_infos) == len(actual_error_infos)
+    for expected, actual in zip(expected_error_infos, actual_error_infos):
+        assert expected.type == actual.type
+        assert expected.reason == actual.reason
+        assert expected.domain == actual.domain
+        assert expected.metadata == actual.metadata
+
+
+def test_api_client_do_custom_headers(requests_mock):
+    client = _BaseClient()
+    requests_mock.get("/test",
+                      json={"well": "done"},
+                      request_headers={
+                          "test": "test",
+                          "User-Agent": useragent.to_string()
+                      })
+    res = client.do("GET", "https://localhost/test", headers={"test": "test"})
+    assert res == {"well": "done"}
+
+
+@pytest.mark.parametrize('status_code,include_retry_after',
+                         ((429, False), (429, True), (503, False), (503, True)))
+def test_http_retry_after(status_code, include_retry_after):
+    requests = []
+
+    def inner(h: BaseHTTPRequestHandler):
+        if len(requests) == 0:
+            h.send_response(status_code)
+            if include_retry_after:
+                h.send_header('Retry-After', '1')
+            h.send_header('Content-Type', 'application/json')
+            h.end_headers()
+        else:
+            h.send_response(200)
+            h.send_header('Content-Type', 'application/json')
+            h.end_headers()
+            h.wfile.write(b'{"foo": 1}')
+        requests.append(h.requestline)
+
+    with http_fixture_server(inner) as host:
+        api_client = _BaseClient(clock=FakeClock())
+        res = api_client.do('GET', f'{host}/foo')
+        assert 'foo' in res
+
+    assert len(requests) == 2
+
+
+def test_http_retry_after_wrong_format():
+    requests = []
+
+    def inner(h: BaseHTTPRequestHandler):
+        if len(requests) == 0:
+            h.send_response(429)
+            h.send_header('Retry-After', '1.58')
+            h.end_headers()
+        else:
+            h.send_response(200)
+            h.send_header('Content-Type', 'application/json')
+            h.end_headers()
+            h.wfile.write(b'{"foo": 1}')
+        requests.append(h.requestline)
+
+    with http_fixture_server(inner) as host:
+        api_client = _BaseClient(clock=FakeClock())
+        res = api_client.do('GET', f'{host}/foo')
+        assert 'foo' in res
+
+    assert len(requests) == 2
+
+
+def test_http_retried_exceed_limit():
+    requests = []
+
+    def inner(h: BaseHTTPRequestHandler):
+        h.send_response(429)
+        h.send_header('Retry-After', '1')
+        h.end_headers()
+        requests.append(h.requestline)
+
+    with http_fixture_server(inner) as host:
+        api_client = _BaseClient(retry_timeout_seconds=1, clock=FakeClock())
+        with pytest.raises(TimeoutError):
+            api_client.do('GET', f'{host}/foo')
+
+    assert len(requests) == 1
+
+
+def test_http_retried_on_match():
+    requests = []
+
+    def inner(h: BaseHTTPRequestHandler):
+        if len(requests) == 0:
+            h.send_response(400)
+            h.end_headers()
+            h.wfile.write(b'{"error_code": "abc", "message": "... ClusterNotReadyException ..."}')
+        else:
+            h.send_response(200)
+            h.end_headers()
+            h.wfile.write(b'{"foo": 1}')
+        requests.append(h.requestline)
+
+    with http_fixture_server(inner) as host:
+        api_client = _BaseClient(clock=FakeClock())
+        res = api_client.do('GET', f'{host}/foo')
+        assert 'foo' in res
+
+    assert len(requests) == 2
+
+
+def test_http_not_retried_on_normal_errors():
+    requests = []
+
+    def inner(h: BaseHTTPRequestHandler):
+        if len(requests) == 0:
+            h.send_response(400)
+            h.end_headers()
+            h.wfile.write(b'{"error_code": "abc", "message": "something not found"}')
+        requests.append(h.requestline)
+
+    with http_fixture_server(inner) as host:
+        api_client = _BaseClient(clock=FakeClock())
+        with pytest.raises(DatabricksError):
+            api_client.do('GET', f'{host}/foo')
+
+    assert len(requests) == 1
+
+
+def test_http_retried_on_connection_error():
+    requests = []
+
+    def inner(h: BaseHTTPRequestHandler):
+        if len(requests) > 0:
+            h.send_response(200)
+            h.end_headers()
+            h.wfile.write(b'{"foo": 1}')
+        requests.append(h.requestline)
+
+    with http_fixture_server(inner) as host:
+        api_client = _BaseClient(clock=FakeClock())
+        res = api_client.do('GET', f'{host}/foo')
+        assert 'foo' in res
+
+    assert len(requests) == 2
diff --git a/tests/test_core.py b/tests/test_core.py
index d54563d4e..16a4c2ad6 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,21 +1,15 @@
-import contextlib
-import functools
 import os
 import pathlib
 import platform
 import random
 import string
-import typing
 from datetime import datetime
 from http.server import BaseHTTPRequestHandler
-from typing import Iterator, List
 
 import pytest
-import requests
 
 from databricks.sdk import WorkspaceClient, errors
-from databricks.sdk.core import (ApiClient, Config, DatabricksError,
-                                 StreamingResponse)
+from databricks.sdk.core import ApiClient, Config, DatabricksError
 from databricks.sdk.credentials_provider import (CliTokenSource,
                                                  CredentialsProvider,
                                                  CredentialsStrategy,
@@ -28,8 +22,8 @@ from databricks.sdk.service.iam import AccessControlRequest
 from databricks.sdk.version import __version__
 
-from .clock import FakeClock
 from .conftest import noop_credentials
+from .fixture_server import http_fixture_server
 
 
 def test_parse_dsn():
@@ -80,32 +74,6 @@ def write_small_dummy_executable(path: pathlib.Path):
     return cli
 
 
-def test_streaming_response_read(config):
-    content = b"some initial binary data: \x00\x01"
-    response = StreamingResponse(DummyResponse([content]))
-    assert response.read() == content
-
-
-def test_streaming_response_read_partial(config):
-    content = b"some initial binary data: \x00\x01"
-    response = StreamingResponse(DummyResponse([content]))
-    assert response.read(8) == b"some ini"
-
-
-def test_streaming_response_read_full(config):
-    content = b"some initial binary data: \x00\x01"
-    response = StreamingResponse(DummyResponse([content, content]))
-    assert response.read() == content + content
-
-
-def test_streaming_response_read_closes(config):
-    content = b"some initial binary data: \x00\x01"
-    dummy_response = DummyResponse([content])
-    with StreamingResponse(dummy_response) as response:
-        assert response.read() == content
-    assert dummy_response.isClosed()
-
-
 def write_large_dummy_executable(path: pathlib.Path):
     cli = path.joinpath('databricks')
 
@@ -290,36 +258,6 @@ def test_config_parsing_non_string_env_vars(monkeypatch):
     assert c.debug_truncate_bytes == 100
 
 
-class DummyResponse(requests.Response):
-    _content: Iterator[bytes]
-    _closed: bool = False
-
-    def __init__(self, content: List[bytes]) -> None:
-        super().__init__()
-        self._content = iter(content)
-
-    def iter_content(self, chunk_size: int = 1, decode_unicode=False) -> Iterator[bytes]:
-        return self._content
-
-    def close(self):
-        self._closed = True
-
-    def isClosed(self):
-        return self._closed
-
-
-def test_api_client_do_custom_headers(config, requests_mock):
-    client = ApiClient(config)
-    requests_mock.get("/test",
-                      json={"well": "done"},
-                      request_headers={
-                          "test": "test",
-                          "User-Agent": config.user_agent
-                      })
-    res = client.do("GET", "/test", headers={"test": "test"})
-    assert res == {"well": "done"}
-
-
 def test_access_control_list(config, requests_mock):
     requests_mock.post("http://localhost/api/2.1/jobs/create",
                        request_headers={"User-Agent": config.user_agent})
@@ -359,81 +297,25 @@ def test_deletes(config, requests_mock):
     assert res is None
 
 
-@pytest.mark.parametrize('status_code,headers,body,expected_error', [
-    (400, {}, {
-        "message":
-        "errorMessage",
-        "details": [{
-            "type": DatabricksError._error_info_type,
-            "reason": "error reason",
-            "domain": "error domain",
-            "metadata": {
-                "etag": "error etag"
-            },
-        }, {
-            "type": "wrong type",
-            "reason": "wrong reason",
-            "domain": "wrong domain",
-            "metadata": {
-                "etag": "wrong etag"
-            }
-        }],
-    },
-     errors.BadRequest('errorMessage',
-                       details=[{
-                           'type': DatabricksError._error_info_type,
-                           'reason': 'error reason',
-                           'domain': 'error domain',
-                           'metadata': {
-                               'etag': 'error etag'
-                           },
-                       }])),
-    (401, {}, {
+@pytest.mark.parametrize(
+    'status_code,headers,body,expected_error',
+    [(401, {}, {
         'error_code': 'UNAUTHORIZED',
         'message': 'errorMessage',
     },
-     errors.Unauthenticated('errorMessage. Config: host=http://localhost, auth_type=noop',
-                            error_code='UNAUTHORIZED')),
-    (403, {}, {
-        'error_code': 'FORBIDDEN',
-        'message': 'errorMessage',
-    },
-     errors.PermissionDenied('errorMessage. Config: host=http://localhost, auth_type=noop',
-                             error_code='FORBIDDEN')),
-    (429, {}, {
-        'error_code': 'TOO_MANY_REQUESTS',
-        'message': 'errorMessage',
-    }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=1)),
-    (429, {
-        'Retry-After': '100'
-    }, {
-        'error_code': 'TOO_MANY_REQUESTS',
-        'message': 'errorMessage',
-    }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)),
-    (503, {}, {
-        'error_code': 'TEMPORARILY_UNAVAILABLE',
-        'message': 'errorMessage',
-    }, errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE',
-                                     retry_after_secs=1)),
-    (503, {
-        'Retry-After': '100'
-    }, {
-        'error_code': 'TEMPORARILY_UNAVAILABLE',
-        'message': 'errorMessage',
-    },
-     errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE',
-                                   retry_after_secs=100)),
-    (404, {}, {
-        'scimType': 'scim type',
-        'detail': 'detail',
-        'status': 'status',
-    }, errors.NotFound('scim type detail', error_code='SCIM_status')),
-])
+     errors.Unauthenticated('errorMessage. Config: host=http://localhost, auth_type=noop',
+                            error_code='UNAUTHORIZED')),
+     (403, {}, {
+         'error_code': 'FORBIDDEN',
+         'message': 'errorMessage',
+     },
+      errors.PermissionDenied('errorMessage. Config: host=http://localhost, auth_type=noop',
+                              error_code='FORBIDDEN')), ])
 def test_error(config, requests_mock, status_code, headers, body, expected_error):
     client = ApiClient(config)
     requests_mock.get("/test", json=body, status_code=status_code, headers=headers)
     with pytest.raises(DatabricksError) as raised:
-        client._perform("GET", "http://localhost/test", headers={"test": "test"})
+        client.do("GET", "/test", headers={"test": "test"})
     actual = raised.value
     assert isinstance(actual, type(expected_error))
     assert str(actual) == str(expected_error)
@@ -448,158 +330,6 @@ def test_error(config, requests_mock, status_code, headers, body, expected_error):
     assert expected.metadata == actual.metadata
 
 
-@contextlib.contextmanager
-def http_fixture_server(handler: typing.Callable[[BaseHTTPRequestHandler], None]):
-    from http.server import HTTPServer
-    from threading import Thread
-
-    class _handler(BaseHTTPRequestHandler):
-
-        def __init__(self, handler: typing.Callable[[BaseHTTPRequestHandler], None], *args):
-            self._handler = handler
-            super().__init__(*args)
-
-        def __getattr__(self, item):
-            if 'do_' != item[0:3]:
-                raise AttributeError(f'method {item} not found')
-            return functools.partial(self._handler, self)
-
-    handler_factory = functools.partial(_handler, handler)
-    srv = HTTPServer(('localhost', 0), handler_factory)
-    t = Thread(target=srv.serve_forever)
-    try:
-        t.daemon = True
-        t.start()
-        yield 'http://{0}:{1}'.format(*srv.server_address)
-    finally:
-        srv.shutdown()
-
-
-@pytest.mark.parametrize('status_code,include_retry_after',
-                         ((429, False), (429, True), (503, False), (503, True)))
-def test_http_retry_after(status_code, include_retry_after):
-    requests = []
-
-    def inner(h: BaseHTTPRequestHandler):
-        if len(requests) == 0:
-            h.send_response(status_code)
-            if include_retry_after:
-                h.send_header('Retry-After', '1')
-            h.send_header('Content-Type', 'application/json')
-            h.end_headers()
-        else:
-            h.send_response(200)
-            h.send_header('Content-Type', 'application/json')
-            h.end_headers()
-            h.wfile.write(b'{"foo": 1}')
-        requests.append(h.requestline)
-
-    with http_fixture_server(inner) as host:
-        api_client = ApiClient(Config(host=host, token='_', clock=FakeClock()))
-        res = api_client.do('GET', '/foo')
-        assert 'foo' in res
-
-    assert len(requests) == 2
-
-
-def test_http_retry_after_wrong_format():
-    requests = []
-
-    def inner(h: BaseHTTPRequestHandler):
-        if len(requests) == 0:
-            h.send_response(429)
-            h.send_header('Retry-After', '1.58')
-            h.end_headers()
-        else:
-            h.send_response(200)
-            h.send_header('Content-Type', 'application/json')
-            h.end_headers()
-            h.wfile.write(b'{"foo": 1}')
-        requests.append(h.requestline)
-
-    with http_fixture_server(inner) as host:
-        api_client = ApiClient(Config(host=host, token='_', clock=FakeClock()))
-        res = api_client.do('GET', '/foo')
-        assert 'foo' in res
-
-    assert len(requests) == 2
-
-
-def test_http_retried_exceed_limit():
-    requests = []
-
-    def inner(h: BaseHTTPRequestHandler):
-        h.send_response(429)
-        h.send_header('Retry-After', '1')
-        h.end_headers()
-        requests.append(h.requestline)
-
-    with http_fixture_server(inner) as host:
-        api_client = ApiClient(Config(host=host, token='_', retry_timeout_seconds=1, clock=FakeClock()))
-        with pytest.raises(TimeoutError):
-            api_client.do('GET', '/foo')
-
-    assert len(requests) == 1
-
-
-def test_http_retried_on_match():
-    requests = []
-
-    def inner(h: BaseHTTPRequestHandler):
-        if len(requests) == 0:
-            h.send_response(400)
-            h.end_headers()
-            h.wfile.write(b'{"error_code": "abc", "message": "... ClusterNotReadyException ..."}')
-        else:
-            h.send_response(200)
-            h.end_headers()
-            h.wfile.write(b'{"foo": 1}')
-        requests.append(h.requestline)
-
-    with http_fixture_server(inner) as host:
-        api_client = ApiClient(Config(host=host, token='_', clock=FakeClock()))
-        res = api_client.do('GET', '/foo')
-        assert 'foo' in res
-
-    assert len(requests) == 2
-
-
-def test_http_not_retried_on_normal_errors():
-    requests = []
-
-    def inner(h: BaseHTTPRequestHandler):
-        if len(requests) == 0:
-            h.send_response(400)
-            h.end_headers()
-            h.wfile.write(b'{"error_code": "abc", "message": "something not found"}')
-        requests.append(h.requestline)
-
-    with http_fixture_server(inner) as host:
-        api_client = ApiClient(Config(host=host, token='_', clock=FakeClock()))
-        with pytest.raises(DatabricksError):
-            api_client.do('GET', '/foo')
-
-    assert len(requests) == 1
-
-
-def test_http_retried_on_connection_error():
-    requests = []
-
-    def inner(h: BaseHTTPRequestHandler):
-        if len(requests) > 0:
-            h.send_response(200)
-            h.end_headers()
-            h.wfile.write(b'{"foo": 1}')
-        requests.append(h.requestline)
-
-    with http_fixture_server(inner) as host:
-        api_client = ApiClient(Config(host=host, token='_', clock=FakeClock()))
-        res = api_client.do('GET', '/foo')
-        assert 'foo' in res
-
-    assert len(requests) == 2
-
-
 def test_github_oidc_flow_works_with_azure(monkeypatch):
 
     def inner(h: BaseHTTPRequestHandler):