Skip to content

Commit c18dc36

Browse files
committed
Merge branch 'main' into update-api-spec
2 parents 9d27212 + 21f8ff7 commit c18dc36

File tree

15 files changed

+784
-12
lines changed

15 files changed

+784
-12
lines changed

NEXT_CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
### New Features and Improvements
66

7-
* Add native support for authentication through Azure DevOps OIDC
7+
* Add native support for authentication through Azure DevOps OIDC.
88

99
### Bug Fixes
10+
* Fix a security issue that resulted in bearer tokens being logged in exception messages.
1011

1112
### Documentation
1213

databricks/sdk/_base_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ def __init__(
9999
# Default to 60 seconds
100100
self._http_timeout_seconds = http_timeout_seconds or 60
101101

102-
self._error_parser = _Parser(extra_error_customizers=extra_error_customizers)
102+
self._error_parser = _Parser(
103+
extra_error_customizers=extra_error_customizers,
104+
debug_headers=debug_headers,
105+
)
103106

104107
def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
105108
if self._header_factory:

databricks/sdk/common/lro.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from datetime import timedelta
2+
from typing import Optional
3+
4+
5+
class LroOptions:
    """Options for Long Running Operations (LRO).

    DO NOT USE THIS OPTION. This option is still under development
    and can be updated in the future without notice.
    """

    def __init__(self, *, timeout: Optional[timedelta] = None):
        """
        Args:
            timeout: The timeout for the Long Running Operations.
                If not set (or falsy, e.g. ``timedelta(0)``), the default
                timeout of 20 minutes is used.
        """
        # `or` (not an `is None` check) so a zero timeout also falls back
        # to the default, matching the documented contract.
        self.timeout = timeout or timedelta(minutes=20)

databricks/sdk/common/types/__init__.py

Whitespace-only changes.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
class FieldMask(object):
    """Class for FieldMask message type.

    Follows the protobuf FieldMask JSON mapping: ``paths`` is a list of
    field paths, serialized to/from a single comma-separated string.
    """

    # This is based on the base implementation from protobuf.
    # https://pigweed.googlesource.com/third_party/github/protocolbuffers/protobuf/+/HEAD/python/google/protobuf/internal/field_mask.py
    # The original implementation only works with proto generated classes.
    # Since our classes are not generated from proto files, we need to implement it manually.

    def __init__(self, field_mask=None):
        """Initializes the FieldMask.

        :param field_mask: Optional list of field paths; defaults to an
            empty list of paths.
        """
        # Bug fix: always initialize `paths`. Previously a bare FieldMask()
        # left the attribute unset, so ToJsonString/__eq__/__hash__/__repr__
        # raised AttributeError until FromJsonString was called.
        self.paths = field_mask if field_mask else []

    def ToJsonString(self) -> str:
        """Converts FieldMask to its comma-separated string form."""
        return ",".join(self.paths)

    def FromJsonString(self, value: str) -> None:
        """Parses a comma-separated string into this FieldMask.

        :raises ValueError: if ``value`` is not a string.
        """
        if not isinstance(value, str):
            raise ValueError("FieldMask JSON value not a string: {!r}".format(value))
        if value:
            self.paths = value.split(",")
        else:
            self.paths = []

    def __eq__(self, other) -> bool:
        """Check equality based on paths."""
        if not isinstance(other, FieldMask):
            return False
        return self.paths == other.paths

    def __hash__(self) -> int:
        """Hash based on paths tuple (lists are unhashable)."""
        return hash(tuple(self.paths))

    def __repr__(self) -> str:
        """String representation for debugging."""
        return f"FieldMask(paths={self.paths})"

databricks/sdk/dbutils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,11 @@ def __init__(self) -> None:
210210
class RemoteDbUtils:
211211

212212
def __init__(self, config: "Config" = None):
213-
self._config = Config() if not config else config
213+
# Create a shallow copy of the config to allow the use of a custom
214+
# user-agent while avoiding modifying the original config.
215+
self._config = Config() if not config else config.copy()
216+
self._config.with_user_agent_extra("dbutils", "remote")
217+
214218
self._client = ApiClient(self._config)
215219
self._clusters = compute_ext.ClustersExt(self._client)
216220
self._commands = compute.CommandExecutionAPI(self._client)

databricks/sdk/errors/parser.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,15 @@
3131
]
3232

3333

34-
def _unknown_error(response: requests.Response) -> str:
34+
def _unknown_error(response: requests.Response, debug_headers: bool = False) -> str:
3535
"""A standard error message that can be shown when an API response cannot be parsed.
3636
3737
This error message includes a link to the issue tracker for the SDK for users to report the issue to us.
38+
39+
:param response: The response object from the API request.
40+
:param debug_headers: Whether to include headers in the request log. Defaults to False to defensively handle cases where request headers might contain sensitive data (e.g. tokens).
3841
"""
39-
request_log = RoundTrip(response, debug_headers=True, debug_truncate_bytes=10 * 1024).generate()
42+
request_log = RoundTrip(response, debug_headers=debug_headers, debug_truncate_bytes=10 * 1024).generate()
4043
return (
4144
"This is likely a bug in the Databricks SDK for Python or the underlying "
4245
"API. Please report this issue with the following debugging information to the SDK issue tracker at "
@@ -56,11 +59,13 @@ def __init__(
5659
self,
5760
extra_error_parsers: List[_ErrorDeserializer] = [],
5861
extra_error_customizers: List[_ErrorCustomizer] = [],
62+
debug_headers: bool = False,
5963
):
6064
self._error_parsers = _error_deserializers + (extra_error_parsers if extra_error_parsers is not None else [])
6165
self._error_customizers = _error_customizers + (
6266
extra_error_customizers if extra_error_customizers is not None else []
6367
)
68+
self._debug_headers = debug_headers
6469

6570
def get_api_error(self, response: requests.Response) -> Optional[DatabricksError]:
6671
"""
@@ -84,7 +89,7 @@ def get_api_error(self, response: requests.Response) -> Optional[DatabricksError
8489
)
8590
return _error_mapper(
8691
response,
87-
{"message": "unable to parse response. " + _unknown_error(response)},
92+
{"message": "unable to parse response. " + _unknown_error(response, self._debug_headers)},
8893
)
8994

9095
# Private link failures happen via a redirect to the login page. From a requests-perspective, the request

databricks/sdk/retries.py

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import functools
22
import logging
33
from datetime import timedelta
4-
from random import random
5-
from typing import Callable, Optional, Sequence, Type
4+
from random import random, uniform
5+
from typing import Callable, Optional, Sequence, Tuple, Type, TypeVar
66

77
from .clock import Clock, RealClock
88

99
logger = logging.getLogger(__name__)
1010

11+
T = TypeVar("T")
12+
1113

1214
def retried(
1315
*,
@@ -67,3 +69,101 @@ def wrapper(*args, **kwargs):
6769
return wrapper
6870

6971
return decorator
72+
73+
74+
class RetryError(Exception):
    """Error that can be returned from poll functions to control retry behavior.

    NOTE(review): the instance attribute ``halt`` shadows the static factory
    ``RetryError.halt`` on instances — the class provides the factory, the
    instance carries the boolean flag.
    """

    def __init__(self, err: Exception, halt: bool = False):
        super().__init__(str(err))
        # The underlying error, and whether polling should stop immediately.
        self.err = err
        self.halt = halt

    @staticmethod
    def continues(msg: str) -> "RetryError":
        """Create a non-halting retry error with a message."""
        return RetryError(Exception(msg))

    @staticmethod
    def halt(err: Exception) -> "RetryError":
        """Create a halting retry error."""
        return RetryError(err, halt=True)
91+
92+
93+
def _backoff(attempt: int) -> float:
94+
"""Calculate backoff time with jitter.
95+
96+
Linear backoff: attempt * 1 second, capped at 10 seconds
97+
Plus random jitter between 50ms and 750ms.
98+
"""
99+
wait = min(10, attempt)
100+
jitter = uniform(0.05, 0.75)
101+
return wait + jitter
102+
103+
104+
def poll(
    fn: Callable[[], Tuple[Optional[T], Optional[RetryError]]],
    timeout: timedelta = timedelta(minutes=20),
    clock: Optional[Clock] = None,
) -> T:
    """Poll a function until it succeeds or times out.

    The backoff is linear backoff with jitter (see ``_backoff``).

    This function is not meant to be used directly by users.
    It is used internally by the SDK to poll for the result of an operation.
    It can be changed in the future without any notice.

    :param fn: Function that returns (result, error).
        Return (None, RetryError.continues("msg")) to continue polling.
        Return (None, RetryError.halt(err)) to stop with error.
        Return (result, None) on success.
    :param timeout: Maximum time to poll (default: 20 minutes)
    :param clock: Clock implementation for testing (default: RealClock)
    :returns: The result of the successful function call
    :raises TimeoutError: If the timeout is reached
    :raises Exception: If a halting error is encountered, or if ``fn``
        itself raises

    Example:
        def check_operation():
            op = get_operation()
            if not op.done:
                return None, RetryError.continues("operation still in progress")
            if op.error:
                return None, RetryError.halt(Exception(f"operation failed: {op.error}"))
            return op.result, None

        result = poll(check_operation, timeout=timedelta(minutes=5))
    """
    if clock is None:
        clock = RealClock()

    deadline = clock.time() + timeout.total_seconds()
    attempt = 0
    last_err: Optional[Exception] = None

    while clock.time() < deadline:
        attempt += 1

        # Any exception raised by fn itself propagates immediately: it is an
        # unexpected failure, not a polling outcome. (The previous version
        # wrapped this in try/except clauses that only re-raised, which
        # obscured that fact without changing behavior.)
        result, err = fn()

        if err is None:
            return result

        if err.halt:
            # A halting outcome carries the caller's original error.
            raise err.err

        # Non-halting error: remember it as the eventual TimeoutError cause,
        # back off, and try again.
        last_err = err.err
        wait = _backoff(attempt)
        logger.debug(f"{str(err.err).rstrip('.')}. Sleeping {wait:.3f}s")
        clock.sleep(wait)

    raise TimeoutError(f"Timed out after {timeout}") from last_err

databricks/sdk/service/_internal.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import datetime
22
import urllib.parse
3-
from typing import Callable, Dict, Generic, Optional, Type, TypeVar
3+
from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar
4+
5+
from google.protobuf.duration_pb2 import Duration
6+
from google.protobuf.timestamp_pb2 import Timestamp
7+
8+
from databricks.sdk.common.types.fieldmask import FieldMask
49

510

611
def _from_dict(d: Dict[str, any], field: str, cls: Type) -> any:
@@ -46,6 +51,93 @@ def _escape_multi_segment_path_parameter(param: str) -> str:
4651
return urllib.parse.quote(param)
4752

4853

54+
def _timestamp(d: Dict[str, any], field: str) -> Optional[Timestamp]:
    """Parse the timestamp string stored under ``field`` in ``d``.

    Returns None when the field is absent or empty; otherwise a Timestamp
    built from the JSON string representation.
    """
    raw = d.get(field)
    if not raw:
        return None
    ts = Timestamp()
    ts.FromJsonString(raw)
    return ts
65+
66+
67+
def _repeated_timestamp(d: Dict[str, any], field: str) -> Optional[List[Timestamp]]:
    """Parse the list of timestamp strings stored under ``field`` in ``d``.

    Returns None when the field is absent or empty; otherwise a list of
    Timestamp objects, one per JSON string in the input list.
    """
    raw = d.get(field)
    if not raw:
        return None

    def parse(value: str) -> Timestamp:
        ts = Timestamp()
        ts.FromJsonString(value)
        return ts

    return [parse(v) for v in raw]
81+
82+
83+
def _duration(d: Dict[str, any], field: str) -> Optional[Duration]:
    """Parse the duration string stored under ``field`` in ``d``.

    Returns None when the field is absent or empty; otherwise a Duration
    built from the JSON string representation.
    """
    raw = d.get(field)
    if not raw:
        return None
    dur = Duration()
    dur.FromJsonString(raw)
    return dur
94+
95+
96+
def _repeated_duration(d: Dict[str, any], field: str) -> Optional[List[Duration]]:
    """Parse the list of duration strings stored under ``field`` in ``d``.

    Returns None when the field is absent or empty; otherwise a list of
    Duration objects, one per JSON string in the input list.
    """
    raw = d.get(field)
    if not raw:
        return None

    def parse(value: str) -> Duration:
        dur = Duration()
        dur.FromJsonString(value)
        return dur

    return [parse(v) for v in raw]
110+
111+
112+
def _fieldmask(d: Dict[str, any], field: str) -> Optional[FieldMask]:
    """Parse the field-mask string stored under ``field`` in ``d``.

    Returns None when the field is absent or empty; otherwise a FieldMask
    built from the comma-separated JSON string representation.
    """
    raw = d.get(field)
    if not raw:
        return None
    fm = FieldMask()
    fm.FromJsonString(raw)
    return fm
123+
124+
125+
def _repeated_fieldmask(d: Dict[str, any], field: str) -> Optional[List[FieldMask]]:
    """Parse the list of field-mask strings stored under ``field`` in ``d``.

    Returns None when the field is absent or empty; otherwise a list of
    FieldMask objects, one per comma-separated JSON string in the input list.
    """
    raw = d.get(field)
    if not raw:
        return None

    def parse(value: str) -> FieldMask:
        fm = FieldMask()
        fm.FromJsonString(value)
        return fm

    return [parse(v) for v in raw]
139+
140+
49141
ReturnType = TypeVar("ReturnType")
50142

51143

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ classifiers = [
2727
dependencies = [
2828
"requests>=2.28.1,<3",
2929
"google-auth~=2.0",
30+
"protobuf>=4.21.0,<7.0",
3031
]
3132

3233
[project.urls]

0 commit comments

Comments
 (0)