Skip to content

Commit 9107e9c

Browse files
committed
[Internal] Add DataPlane token source
1 parent e550ca1 commit 9107e9c

File tree

3 files changed

+169
-1
lines changed

3 files changed

+169
-1
lines changed

databricks/sdk/data_plane.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,81 @@
1+
from __future__ import annotations
2+
13
import threading
24
from dataclasses import dataclass
3-
from typing import Callable, List
5+
from typing import Callable, List, Optional
6+
from urllib import parse
47

8+
from databricks.sdk import oauth
59
from databricks.sdk.oauth import Token
610

11+
URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
12+
JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
13+
OIDC_TOKEN_PATH = "/oidc/v1/token"
14+
15+
16+
class DataPlaneTokenSource:
17+
"""
18+
EXPERIMENTAL Manages token sources for multiple DataPlane endpoints.
19+
"""
20+
21+
# TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled.
22+
def __init__(self,
23+
token_exchange_host: str,
24+
cpts: Callable[[], Token],
25+
disable_async: Optional[bool] = True):
26+
self._cpts = cpts
27+
self._token_exchange_host = token_exchange_host
28+
self._token_sources = {}
29+
self._disable_async = disable_async
30+
self._lock = threading.Lock()
31+
32+
def get_token(self, endpoint, auth_details):
33+
key = f"{endpoint}:{auth_details}"
34+
35+
# First, try to read without acquiring the lock to avoid contention.
36+
# Reads are atomic, so this is safe.
37+
token_source = self._token_sources.get(key)
38+
if token_source:
39+
return token_source.token()
40+
41+
# If token_source is not found, acquire the lock and check again.
42+
with self._lock:
43+
# Another thread might have created it while we were waiting for the lock.
44+
token_source = self._token_sources.get(key)
45+
if not token_source:
46+
token_source = DataPlaneEndpointTokenSource(self._token_exchange_host, self._cpts,
47+
auth_details, self._disable_async)
48+
self._token_sources[key] = token_source
49+
50+
return token_source.token()
51+
52+
53+
class DataPlaneEndpointTokenSource(oauth.Refreshable):
54+
"""
55+
EXPERIMENTAL A token source for a specific DataPlane endpoint.
56+
"""
57+
58+
def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_details: str,
59+
disable_async: bool):
60+
super().__init__(disable_async=disable_async)
61+
self._auth_details = auth_details
62+
self._cpts = cpts
63+
self._token_exchange_host = token_exchange_host
64+
65+
def refresh(self) -> Token:
66+
control_plane_token = self._cpts()
67+
headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE}
68+
params = parse.urlencode({
69+
"grant_type": JWT_BEARER_GRANT_TYPE,
70+
"authorization_details": self._auth_details,
71+
"assertion": control_plane_token.access_token
72+
})
73+
return oauth.retrieve_token(client_id="",
74+
client_secret="",
75+
token_url=self._token_exchange_host + OIDC_TOKEN_PATH,
76+
params=params,
77+
headers=headers)
78+
779

880
@dataclass
981
class DataPlaneDetails:
@@ -16,6 +88,9 @@ class DataPlaneDetails:
1688
"""Token to query the DataPlane endpoint."""
1789

1890

91+
## Old implementation. #TODO: Remove after the new implementation is used
92+
93+
1994
class DataPlaneService:
2095
"""Helper class to fetch and manage DataPlane details."""
2196
from .service.serving import DataPlaneInfo
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from databricks.sdk.data_plane import DataPlaneTokenSource
2+
3+
4+
def test_data_plane_token_source(ucws, env_or_skip):
5+
endpoint = env_or_skip("SERVING_ENDPOINT_NAME")
6+
serving_endpoint = ucws.serving_endpoints.get(endpoint)
7+
assert serving_endpoint.data_plane_info is not None
8+
assert serving_endpoint.data_plane_info.query_info is not None
9+
10+
info = serving_endpoint.data_plane_info.query_info
11+
12+
ts = DataPlaneTokenSource(ucws.config.host, ucws._config.oauth_token)
13+
dp_token = ts.token(info.endpoint_url, info.authorization_details)
14+
15+
assert dp_token.valid

tests/test_data_plane.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,87 @@
11
from datetime import datetime, timedelta
2+
from unittest.mock import patch
3+
from urllib import parse
24

5+
from databricks.sdk import data_plane, oauth
36
from databricks.sdk.data_plane import DataPlaneService
47
from databricks.sdk.oauth import Token
58
from databricks.sdk.service.serving import DataPlaneInfo
69

10+
cp_token = Token(access_token="control plane token",
11+
token_type="type",
12+
expiry=datetime.now() + timedelta(hours=1))
13+
dp_token = Token(access_token="data plane token",
14+
token_type="type",
15+
expiry=datetime.now() + timedelta(hours=1))
16+
17+
18+
def success_callable(token: oauth.Token):
19+
20+
def success() -> oauth.Token:
21+
return token
22+
23+
return success
24+
25+
26+
def test_endpoint_token_source_get_token(config):
27+
token_source = data_plane.DataPlaneEndpointTokenSource(config.host,
28+
success_callable(cp_token),
29+
"authDetails",
30+
disable_async=True)
31+
32+
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
33+
token_source.token()
34+
35+
retrieve_token.assert_called_once()
36+
args, kwargs = retrieve_token.call_args
37+
38+
assert kwargs["token_url"] == config.host + "/oidc/v1/token"
39+
assert kwargs["params"] == parse.urlencode({
40+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
41+
"authorization_details": "authDetails",
42+
"assertion": cp_token.access_token
43+
})
44+
assert kwargs["headers"] == {"Content-Type": "application/x-www-form-urlencoded"}
45+
46+
47+
def test_token_source_get_token_not_existing(config):
48+
token_source = data_plane.DataPlaneTokenSource(config.host,
49+
success_callable(cp_token),
50+
disable_async=True)
51+
52+
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
53+
result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")
54+
55+
retrieve_token.assert_called_once()
56+
assert result_token.access_token == dp_token.access_token
57+
assert "endpoint:authDetails" in token_source._token_sources
58+
59+
60+
class MockEndpointTokenSource:
61+
62+
def __init__(self, token: oauth.Token):
63+
self._token = token
64+
65+
def token(self):
66+
return self._token
67+
68+
69+
def test_token_source_get_token_existing(config):
70+
another_token = Token(access_token="another token",
71+
token_type="type",
72+
expiry=datetime.now() + timedelta(hours=1))
73+
token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(token), disable_async=True)
74+
token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token)
75+
76+
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
77+
result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")
78+
79+
retrieve_token.assert_not_called()
80+
assert result_token.access_token == another_token.access_token
81+
82+
83+
## These tests are for the old implementation. #TODO: Remove after the new implementation is used
84+
785
info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url")
886

987
token = Token(access_token="token", token_type="type", expiry=datetime.now() + timedelta(hours=1))

0 commit comments

Comments
 (0)