Skip to content

Commit f06bb27

Browse files
authored
[Fix] Fix deserialization of 401/403 errors (#758)
## Changes #741 introduced a change to how an error message was modified in `ApiClient._perform`. Previously, arguments to the DatabricksError constructor were modified as a dictionary in `_perform`. After that change, `get_api_error` started to return a `DatabricksError` instance whose attributes were modified. The `message` attribute referred to in that change does not exist in the DatabricksError class: there is a `message` constructor parameter, but it is not set as an attribute. This PR refactors the error handling logic slightly to restore the original behavior. In doing this, we decouple all error-parsing and customizing logic out of ApiClient. This also sets us up to allow for further extension of error parsing and customization in the future, a feature that I have seen present in other SDKs. Fixes #755. ## Tests <!-- How is this tested? Please see the checklist below and also describe any other relevant tests --> - [ ] `make test` run locally - [ ] `make fmt` applied - [ ] relevant integration tests applied
1 parent c3aad28 commit f06bb27

File tree

8 files changed

+324
-184
lines changed

8 files changed

+324
-184
lines changed

databricks/sdk/core.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .config import *
1111
# To preserve backwards compatibility (as these definitions were previously in this module)
1212
from .credentials_provider import *
13-
from .errors import DatabricksError, get_api_error
13+
from .errors import DatabricksError, _ErrorCustomizer, _Parser
1414
from .logger import RoundTrip
1515
from .oauth import retrieve_token
1616
from .retries import retried
@@ -71,6 +71,8 @@ def __init__(self, cfg: Config = None):
7171
# Default to 60 seconds
7272
self._http_timeout_seconds = cfg.http_timeout_seconds if cfg.http_timeout_seconds else 60
7373

74+
self._error_parser = _Parser(extra_error_customizers=[_AddDebugErrorCustomizer(cfg)])
75+
7476
@property
7577
def account_id(self) -> str:
7678
return self._cfg.account_id
@@ -219,27 +221,6 @@ def _is_retryable(err: BaseException) -> Optional[str]:
219221
return f'matched {substring}'
220222
return None
221223

222-
@classmethod
223-
def _parse_retry_after(cls, response: requests.Response) -> Optional[int]:
224-
retry_after = response.headers.get("Retry-After")
225-
if retry_after is None:
226-
# 429 requests should include a `Retry-After` header, but if it's missing,
227-
# we default to 1 second.
228-
return cls._RETRY_AFTER_DEFAULT
229-
# If the request is throttled, try parse the `Retry-After` header and sleep
230-
# for the specified number of seconds. Note that this header can contain either
231-
# an integer or a RFC1123 datetime string.
232-
# See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
233-
#
234-
# For simplicity, we only try to parse it as an integer, as this is what Databricks
235-
# platform returns. Otherwise, we fall back and don't sleep.
236-
try:
237-
return int(retry_after)
238-
except ValueError:
239-
logger.debug(f'Invalid Retry-After header received: {retry_after}. Defaulting to 1')
240-
# defaulting to 1 sleep second to make self._is_retryable() simpler
241-
return cls._RETRY_AFTER_DEFAULT
242-
243224
def _perform(self,
244225
method: str,
245226
url: str,
@@ -261,15 +242,8 @@ def _perform(self,
261242
stream=raw,
262243
timeout=self._http_timeout_seconds)
263244
self._record_request_log(response, raw=raw or data is not None or files is not None)
264-
error = get_api_error(response)
245+
error = self._error_parser.get_api_error(response)
265246
if error is not None:
266-
status_code = response.status_code
267-
is_http_unauthorized_or_forbidden = status_code in (401, 403)
268-
is_too_many_requests_or_unavailable = status_code in (429, 503)
269-
if is_http_unauthorized_or_forbidden:
270-
error.message = self._cfg.wrap_debug_info(error.message)
271-
if is_too_many_requests_or_unavailable:
272-
error.retry_after_secs = self._parse_retry_after(response)
273247
raise error from None
274248
return response
275249

@@ -279,6 +253,19 @@ def _record_request_log(self, response: requests.Response, raw: bool = False) ->
279253
logger.debug(RoundTrip(response, self._cfg.debug_headers, self._debug_truncate_bytes, raw).generate())
280254

281255

256+
class _AddDebugErrorCustomizer(_ErrorCustomizer):
257+
"""An error customizer that adds debug information about the configuration to unauthenticated and
258+
unauthorized errors."""
259+
260+
def __init__(self, cfg: Config):
261+
self._cfg = cfg
262+
263+
def customize_error(self, response: requests.Response, kwargs: dict):
264+
if response.status_code in (401, 403):
265+
message = kwargs.get('message', 'request failed')
266+
kwargs['message'] = self._cfg.wrap_debug_info(message)
267+
268+
282269
class StreamingResponse(BinaryIO):
283270
_response: requests.Response
284271
_buffer: bytes

databricks/sdk/errors/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .base import DatabricksError, ErrorDetail
2-
from .mapper import _error_mapper
3-
from .parser import get_api_error
2+
from .customizer import _ErrorCustomizer
3+
from .parser import _Parser
44
from .platform import *
55
from .private_link import PrivateLinkValidationError
66
from .sdk import *
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import abc
2+
import logging
3+
4+
import requests
5+
6+
7+
class _ErrorCustomizer(abc.ABC):
8+
"""A customizer for errors from the Databricks REST API."""
9+
10+
@abc.abstractmethod
11+
def customize_error(self, response: requests.Response, kwargs: dict):
12+
"""Customize the error constructor parameters."""
13+
14+
15+
class _RetryAfterCustomizer(_ErrorCustomizer):
16+
"""An error customizer that sets the retry_after_secs parameter based on the Retry-After header."""
17+
18+
_DEFAULT_RETRY_AFTER_SECONDS = 1
19+
"""The default number of seconds to wait before retrying a request if the Retry-After header is missing or is not
20+
a valid integer."""
21+
22+
@classmethod
23+
def _parse_retry_after(cls, response: requests.Response) -> int:
24+
retry_after = response.headers.get("Retry-After")
25+
if retry_after is None:
26+
logging.debug(
27+
f'No Retry-After header received in response with status code 429 or 503. Defaulting to {cls._DEFAULT_RETRY_AFTER_SECONDS}'
28+
)
29+
# 429 requests should include a `Retry-After` header, but if it's missing,
30+
# we default to 1 second.
31+
return cls._DEFAULT_RETRY_AFTER_SECONDS
32+
# If the request is throttled, try parse the `Retry-After` header and sleep
33+
# for the specified number of seconds. Note that this header can contain either
34+
# an integer or a RFC1123 datetime string.
35+
# See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
36+
#
37+
# For simplicity, we only try to parse it as an integer, as this is what Databricks
38+
# platform returns. Otherwise, we fall back and don't sleep.
39+
try:
40+
return int(retry_after)
41+
except ValueError:
42+
logging.debug(
43+
f'Invalid Retry-After header received: {retry_after}. Defaulting to {cls._DEFAULT_RETRY_AFTER_SECONDS}'
44+
)
45+
# defaulting to 1 sleep second to make self._is_retryable() simpler
46+
return cls._DEFAULT_RETRY_AFTER_SECONDS
47+
48+
def customize_error(self, response: requests.Response, kwargs: dict):
49+
if response.status_code in (429, 503):
50+
kwargs['retry_after_secs'] = self._parse_retry_after(response)
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import abc
2+
import json
3+
import logging
4+
import re
5+
from typing import Optional
6+
7+
import requests
8+
9+
10+
class _ErrorDeserializer(abc.ABC):
11+
"""A parser for errors from the Databricks REST API."""
12+
13+
@abc.abstractmethod
14+
def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
15+
"""Parses an error from the Databricks REST API. If the error cannot be parsed, returns None."""
16+
17+
18+
class _EmptyDeserializer(_ErrorDeserializer):
19+
"""A parser that handles empty responses."""
20+
21+
def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
22+
if len(response_body) == 0:
23+
return {'message': response.reason}
24+
return None
25+
26+
27+
class _StandardErrorDeserializer(_ErrorDeserializer):
28+
"""
29+
Parses errors from the Databricks REST API using the standard error format.
30+
"""
31+
32+
def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
33+
try:
34+
payload_str = response_body.decode('utf-8')
35+
resp = json.loads(payload_str)
36+
except UnicodeDecodeError as e:
37+
logging.debug('_StandardErrorParser: unable to decode response using utf-8', exc_info=e)
38+
return None
39+
except json.JSONDecodeError as e:
40+
logging.debug('_StandardErrorParser: unable to deserialize response as json', exc_info=e)
41+
return None
42+
if not isinstance(resp, dict):
43+
logging.debug('_StandardErrorParser: response is valid JSON but not a dictionary')
44+
return None
45+
46+
error_args = {
47+
'message': resp.get('message', 'request failed'),
48+
'error_code': resp.get('error_code'),
49+
'details': resp.get('details'),
50+
}
51+
52+
# Handle API 1.2-style errors
53+
if 'error' in resp:
54+
error_args['message'] = resp['error']
55+
56+
# Handle SCIM Errors
57+
detail = resp.get('detail')
58+
status = resp.get('status')
59+
scim_type = resp.get('scimType')
60+
if detail:
61+
# Handle SCIM error message details
62+
# @see https://tools.ietf.org/html/rfc7644#section-3.7.3
63+
if detail == "null":
64+
detail = "SCIM API Internal Error"
65+
error_args['message'] = f"{scim_type} {detail}".strip(" ")
66+
error_args['error_code'] = f"SCIM_{status}"
67+
return error_args
68+
69+
70+
class _StringErrorDeserializer(_ErrorDeserializer):
71+
"""
72+
Parses errors from the Databricks REST API in the format "ERROR_CODE: MESSAGE".
73+
"""
74+
75+
__STRING_ERROR_REGEX = re.compile(r'([A-Z_]+): (.*)')
76+
77+
def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
78+
payload_str = response_body.decode('utf-8')
79+
match = self.__STRING_ERROR_REGEX.match(payload_str)
80+
if not match:
81+
logging.debug('_StringErrorParser: unable to parse response as string')
82+
return None
83+
error_code, message = match.groups()
84+
return {'error_code': error_code, 'message': message, 'status': response.status_code, }
85+
86+
87+
class _HtmlErrorDeserializer(_ErrorDeserializer):
88+
"""
89+
Parses errors from the Databricks REST API in HTML format.
90+
"""
91+
92+
__HTML_ERROR_REGEXES = [re.compile(r'<pre>(.*)</pre>'), re.compile(r'<title>(.*)</title>'), ]
93+
94+
def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
95+
payload_str = response_body.decode('utf-8')
96+
for regex in self.__HTML_ERROR_REGEXES:
97+
match = regex.search(payload_str)
98+
if match:
99+
message = match.group(1) if match.group(1) else response.reason
100+
return {
101+
'status': response.status_code,
102+
'message': message,
103+
'error_code': response.reason.upper().replace(' ', '_')
104+
}
105+
logging.debug('_HtmlErrorParser: no <pre> tag found in error response')
106+
return None

0 commit comments

Comments
 (0)