Skip to content

Commit 7e9430e

Browse files
Merge branch 'main' into renaud.hartert/error-details
2 parents 13a0a2f + 8de985d commit 7e9430e

File tree

6 files changed

+206
-32
lines changed

6 files changed

+206
-32
lines changed

databricks/sdk/data_plane.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,82 @@
1+
from __future__ import annotations
2+
13
import threading
24
from dataclasses import dataclass
3-
from typing import Callable, List
5+
from typing import Callable, List, Optional
6+
from urllib import parse
47

8+
from databricks.sdk import oauth
59
from databricks.sdk.oauth import Token
610

11+
URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
12+
JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
13+
OIDC_TOKEN_PATH = "/oidc/v1/token"
14+
15+
16+
class DataPlaneTokenSource:
    """
    EXPERIMENTAL Manages token sources for multiple DataPlane endpoints.

    Keeps one DataPlaneEndpointTokenSource per (endpoint, auth_details) pair
    and creates them lazily on first request.
    """

    # TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled.
    def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], disable_async: Optional[bool] = True):
        self._token_exchange_host = token_exchange_host
        self._cpts = cpts
        self._disable_async = disable_async
        self._token_sources = {}
        self._lock = threading.Lock()

    def token(self, endpoint, auth_details):
        cache_key = f"{endpoint}:{auth_details}"

        # Fast path: dict reads are atomic in CPython, so probe the cache
        # without taking the lock to avoid contention.
        source = self._token_sources.get(cache_key)
        if source:
            return source.token()

        # Slow path: create the per-endpoint source under the lock,
        # re-checking first in case another thread won the race.
        with self._lock:
            source = self._token_sources.get(cache_key)
            if not source:
                source = DataPlaneEndpointTokenSource(
                    self._token_exchange_host, self._cpts, auth_details, self._disable_async
                )
                self._token_sources[cache_key] = source

        return source.token()
49+
50+
51+
class DataPlaneEndpointTokenSource(oauth.Refreshable):
    """
    EXPERIMENTAL A token source for a specific DataPlane endpoint.

    Exchanges a control-plane token for a data-plane token via the OIDC
    JWT-bearer grant whenever a refresh is needed.
    """

    def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_details: str, disable_async: bool):
        super().__init__(disable_async=disable_async)
        self._token_exchange_host = token_exchange_host
        self._cpts = cpts
        self._auth_details = auth_details

    def refresh(self) -> Token:
        # Fetch a fresh control-plane token and present it as the JWT
        # assertion of the token-exchange request.
        cp_token = self._cpts()
        body = parse.urlencode(
            {
                "grant_type": JWT_BEARER_GRANT_TYPE,
                "authorization_details": self._auth_details,
                "assertion": cp_token.access_token,
            }
        )
        return oauth.retrieve_token(
            client_id="",
            client_secret="",
            token_url=self._token_exchange_host + OIDC_TOKEN_PATH,
            params=body,
            headers={"Content-Type": URL_ENCODED_CONTENT_TYPE},
        )
79+
780

881
@dataclass
982
class DataPlaneDetails:
@@ -17,6 +90,9 @@ class DataPlaneDetails:
1790
"""Token to query the DataPlane endpoint."""
1891

1992

93+
## Old implementation. #TODO: Remove after the new implementation is used
94+
95+
2096
class DataPlaneService:
2197
"""Helper class to fetch and manage DataPlane details."""
2298

tests/integration/conftest.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import pytest
99

10-
from databricks.sdk import AccountClient, WorkspaceClient
10+
from databricks.sdk import AccountClient, FilesAPI, FilesExt, WorkspaceClient
1111
from databricks.sdk.service.catalog import VolumeType
1212

1313

@@ -125,6 +125,18 @@ def volume(ucws, schema):
125125
ucws.volumes.delete(volume.full_name)
126126

127127

128+
@pytest.fixture(scope="session", params=[False, True])
def files_api(request, ucws) -> FilesAPI:
    """Parameterized fixture yielding both Files API clients.

    Runs each dependent test twice: once against the default client
    (``ucws.files``) and once against the new extended client (``FilesExt``).
    """
    if not request.param:
        # use the default client
        return ucws.files
    # Force the new client onto the multipart-upload path for files of any
    # size, then return it.
    ucws.config.multipart_upload_min_stream_size = 0
    return FilesExt(ucws.api_client, ucws.config)
138+
139+
128140
@pytest.fixture()
129141
def workspace_dir(w, random):
130142
directory = f"/Users/{w.current_user.me().user_name}/dir-{random(12)}"

tests/integration/test_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,12 @@ def _get_lts_versions(w) -> typing.List[SparkVersion]:
108108
return lts_runtimes
109109

110110

111-
def test_runtime_auth_from_jobs_volumes(ucws, files_api, fresh_wheel_file, env_or_skip, random, volume):
    """Run the jobs runtime-auth scenario with a wheel staged in a UC volume."""
    # Only LTS runtimes from DBR 15 onwards are exercised here.
    lts_15_plus = [v for v in _get_lts_versions(ucws) if int(v.key.split(".")[0]) >= 15]

    wheel_path = f"{volume}/tmp/wheels/{random(10)}/{fresh_wheel_file.name}"
    with fresh_wheel_file.open("rb") as wheel:
        files_api.upload(wheel_path, wheel)

    return _test_runtime_auth_from_jobs_inner(ucws, env_or_skip, random, lts_15_plus, Library(whl=wheel_path))
tests/integration/test_data_plane.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from databricks.sdk.data_plane import DataPlaneTokenSource
2+
3+
4+
def test_data_plane_token_source(ucws, env_or_skip):
    """End-to-end check that DataPlaneTokenSource mints a valid token for a
    serving endpoint's data plane.

    Skips unless SERVING_ENDPOINT_NAME is set in the environment.
    """
    endpoint = env_or_skip("SERVING_ENDPOINT_NAME")
    serving_endpoint = ucws.serving_endpoints.get(endpoint)
    assert serving_endpoint.data_plane_info is not None
    assert serving_endpoint.data_plane_info.query_info is not None

    info = serving_endpoint.data_plane_info.query_info

    # Use the public `config` accessor consistently (the original mixed
    # `ucws.config.host` with the private `ucws._config.oauth_token`;
    # `config` returns the same Config object).
    ts = DataPlaneTokenSource(ucws.config.host, ucws.config.oauth_token)
    dp_token = ts.token(info.endpoint_url, info.authorization_details)

    assert dp_token.valid

tests/integration/test_files.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -218,30 +218,30 @@ def create_volume(w, catalog, schema, volume):
218218
return ResourceWithCleanup(lambda: w.volumes.delete(res.full_name))
219219

220220

221-
def test_files_api_upload_download(ucws, files_api, random):
    """Upload a small payload and read it back via the parameterized client."""
    w = ucws
    schema = "filesit-" + random()
    volume = "filesit-" + random()
    with ResourceWithCleanup.create_schema(w, "main", schema), ResourceWithCleanup.create_volume(w, "main", schema, volume):
        payload = io.BytesIO(b"some text data")
        # '?' and '#' in the name exercise URL-encoding in the client.
        target_file = f"/Volumes/main/{schema}/{volume}/filesit-with-?-and-#-{random()}.txt"
        files_api.upload(target_file, payload)
        with files_api.download(target_file).contents as downloaded:
            assert downloaded.read() == b"some text data"
232232

233233

234-
def test_files_api_read_twice_from_one_download(ucws, random):
234+
def test_files_api_read_twice_from_one_download(ucws, files_api, random):
235235
w = ucws
236236
schema = "filesit-" + random()
237237
volume = "filesit-" + random()
238238
with ResourceWithCleanup.create_schema(w, "main", schema):
239239
with ResourceWithCleanup.create_volume(w, "main", schema, volume):
240240
f = io.BytesIO(b"some text data")
241241
target_file = f"/Volumes/main/{schema}/{volume}/filesit-{random()}.txt"
242-
w.files.upload(target_file, f)
242+
files_api.upload(target_file, f)
243243

244-
res = w.files.download(target_file).contents
244+
res = files_api.download(target_file).contents
245245

246246
with res:
247247
assert res.read() == b"some text data"
@@ -251,82 +251,82 @@ def test_files_api_read_twice_from_one_download(ucws, random):
251251
res.read()
252252

253253

254-
def test_files_api_delete_file(ucws, files_api, random):
    """A freshly uploaded file can be deleted."""
    w = ucws
    schema = "filesit-" + random()
    volume = "filesit-" + random()
    with ResourceWithCleanup.create_schema(w, "main", schema), ResourceWithCleanup.create_volume(w, "main", schema, volume):
        target_file = f"/Volumes/main/{schema}/{volume}/filesit-{random()}.txt"
        files_api.upload(target_file, io.BytesIO(b"some text data"))
        files_api.delete(target_file)
264264

265265

266-
def test_files_api_get_metadata(ucws, files_api, random):
    """Uploaded-file metadata reports content type, length and mtime."""
    w = ucws
    schema = "filesit-" + random()
    volume = "filesit-" + random()
    with ResourceWithCleanup.create_schema(w, "main", schema), ResourceWithCleanup.create_volume(w, "main", schema, volume):
        target_file = f"/Volumes/main/{schema}/{volume}/filesit-{random()}.txt"
        files_api.upload(target_file, io.BytesIO(b"some text data"))
        meta = files_api.get_metadata(target_file)
        assert meta.content_type == "application/octet-stream"
        assert meta.content_length == 14
        assert meta.last_modified is not None
279279

280280

281-
def test_files_api_create_directory(ucws, files_api, random):
    """Directories can be created in a UC volume."""
    w = ucws
    schema = "filesit-" + random()
    volume = "filesit-" + random()
    with ResourceWithCleanup.create_schema(w, "main", schema), ResourceWithCleanup.create_volume(w, "main", schema, volume):
        target_directory = f"/Volumes/main/{schema}/{volume}/filesit-{random()}/"
        files_api.create_directory(target_directory)
289289

290290

291-
def test_files_api_list_directory_contents(ucws, files_api, random):
    """Listing a directory returns one entry per uploaded file."""
    w = ucws
    schema = "filesit-" + random()
    volume = "filesit-" + random()
    with ResourceWithCleanup.create_schema(w, "main", schema), ResourceWithCleanup.create_volume(w, "main", schema, volume):
        target_directory = f"/Volumes/main/{schema}/{volume}/filesit-{random()}"
        for name in ("file1.txt", "file2.txt", "file3.txt"):
            files_api.upload(f"{target_directory}/{name}", io.BytesIO(b"some text data"))

        entries = list(files_api.list_directory_contents(target_directory))
        assert len(entries) == 3
304304

305305

306-
def test_files_api_delete_directory(ucws, files_api, random):
    """A freshly created directory can be deleted."""
    w = ucws
    schema = "filesit-" + random()
    volume = "filesit-" + random()
    with ResourceWithCleanup.create_schema(w, "main", schema), ResourceWithCleanup.create_volume(w, "main", schema, volume):
        target_directory = f"/Volumes/main/{schema}/{volume}/filesit-{random()}/"
        files_api.create_directory(target_directory)
        files_api.delete_directory(target_directory)
315315

316316

317-
def test_files_api_get_directory_metadata(ucws, files_api, random):
    """Directory metadata can be fetched for a freshly created directory."""
    w = ucws
    schema = "filesit-" + random()
    volume = "filesit-" + random()
    with ResourceWithCleanup.create_schema(w, "main", schema), ResourceWithCleanup.create_volume(w, "main", schema, volume):
        target_directory = f"/Volumes/main/{schema}/{volume}/filesit-{random()}/"
        files_api.create_directory(target_directory)
        files_api.get_directory_metadata(target_directory)
326326

327327

328328
@pytest.mark.benchmark
329-
def test_files_api_download_benchmark(ucws, random):
329+
def test_files_api_download_benchmark(ucws, files_api, random):
330330
w = ucws
331331
schema = "filesit-" + random()
332332
volume = "filesit-" + random()
@@ -335,7 +335,7 @@ def test_files_api_download_benchmark(ucws, random):
335335
# Create a 50 MB file
336336
f = io.BytesIO(bytes(range(256)) * 200000)
337337
target_file = f"/Volumes/main/{schema}/{volume}/filesit-benchmark-{random()}.txt"
338-
w.files.upload(target_file, f)
338+
files_api.upload(target_file, f)
339339

340340
totals = {}
341341
for chunk_size_kb in [
@@ -357,7 +357,7 @@ def test_files_api_download_benchmark(ucws, random):
357357
count = 10
358358
for i in range(count):
359359
start = time.time()
360-
f = w.files.download(target_file).contents
360+
f = files_api.download(target_file).contents
361361
f.set_chunk_size(chunk_size)
362362
with f as vf:
363363
vf.read()

tests/test_data_plane.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,80 @@
11
from datetime import datetime, timedelta
2+
from unittest.mock import patch
3+
from urllib import parse
24

5+
from databricks.sdk import data_plane, oauth
36
from databricks.sdk.data_plane import DataPlaneService
47
from databricks.sdk.oauth import Token
58
from databricks.sdk.service.serving import DataPlaneInfo
69

10+
# Shared fixtures for the DataPlane token tests: `cp_token` is what the stubbed
# control-plane credentials provider returns, `dp_token` is what the mocked
# token exchange returns. Both expire one hour from now, so they remain valid
# for the duration of the test run.
cp_token = Token(access_token="control plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
dp_token = Token(access_token="data plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
12+
13+
14+
def success_callable(token: oauth.Token):
    """Return a zero-argument credentials provider that always yields *token*."""

    def _provider() -> oauth.Token:
        return token

    return _provider
20+
21+
22+
def test_endpoint_token_source_get_token(config):
    """refresh() must exchange the control-plane token via the OIDC JWT-bearer grant."""
    source = data_plane.DataPlaneEndpointTokenSource(
        config.host, success_callable(cp_token), "authDetails", disable_async=True
    )

    with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
        source.token()

    retrieve_token.assert_called_once()
    kwargs = retrieve_token.call_args.kwargs

    expected_params = parse.urlencode(
        {
            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
            "authorization_details": "authDetails",
            "assertion": cp_token.access_token,
        }
    )
    assert kwargs["token_url"] == config.host + "/oidc/v1/token"
    assert kwargs["params"] == expected_params
    assert kwargs["headers"] == {"Content-Type": "application/x-www-form-urlencoded"}
42+
43+
44+
def test_token_source_get_token_not_existing(config):
    """A cache miss creates a new endpoint source, caches it, and exchanges a token."""
    source = data_plane.DataPlaneTokenSource(config.host, success_callable(cp_token), disable_async=True)

    with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
        result = source.token(endpoint="endpoint", auth_details="authDetails")

    retrieve_token.assert_called_once()
    assert result.access_token == dp_token.access_token
    assert "endpoint:authDetails" in source._token_sources
53+
54+
55+
class MockEndpointTokenSource:
    """Stand-in endpoint token source that always returns one fixed token."""

    def __init__(self, token: oauth.Token):
        self._token = token

    def token(self):
        # Never refreshes; hands back the token given at construction.
        return self._token
62+
63+
64+
def test_token_source_get_token_existing(config):
    """A cached endpoint source is reused: no token exchange must happen."""
    another_token = Token(access_token="another token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
    # Use `cp_token` like the sibling tests do. The original referenced the
    # module-level `token` from the old-implementation section below, which is
    # slated for removal (see the TODO there) and would break these tests.
    token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(cp_token), disable_async=True)
    token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token)

    with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
        result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")

    retrieve_token.assert_not_called()
    assert result_token.access_token == another_token.access_token
74+
75+
76+
## These tests are for the old implementation. #TODO: Remove after the new implementation is used
77+
778
info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url")
879

980
token = Token(

0 commit comments

Comments
 (0)