Skip to content

Commit 5846242

Browse files
committed
[Internal] Add DataPlane token source
1 parent e550ca1 commit 5846242

File tree

2 files changed

+135
-2
lines changed

2 files changed

+135
-2
lines changed

databricks/sdk/data_plane.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,69 @@
1+
from __future__ import annotations
2+
13
import threading
24
from dataclasses import dataclass
3-
from typing import Callable, List
5+
from typing import Callable, List, Optional
6+
7+
from urllib import parse
48

59
from databricks.sdk.oauth import Token
10+
from databricks.sdk import oauth, config, credentials_provider
11+
12+
13+
URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
14+
JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
15+
OIDC_TOKEN_PATH = "/oidc/v1/token"
16+
17+
class DataPlaneTokenSource:
18+
"""
19+
Manages token sources for multiple DataPlane endpoints.
20+
"""
21+
# TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled.
22+
def __init__(self, cfg: config.Config, cpts: Callable[[], Token], disable_async: Optional[bool] = True):
23+
self._cpts = cpts
24+
self._cfg = cfg
25+
self._token_sources = {}
26+
self._disable_async = disable_async
27+
28+
def token(self, endpoint: str, auth_details: str):
29+
"""
30+
Get a token for a specific DataPlane endpoint.
31+
:param endpoint: endpoint URL for which to get a token
32+
:param auth_details: authorization details used to generate the token
33+
:return: a token for the specified endpoint
34+
"""
35+
key = f"{endpoint}:{auth_details}"
36+
token_source = self._token_sources.get(key)
37+
if not token_source:
38+
token_source = DataPlaneEndpointTokenSource(self._cfg, self._cpts, auth_details, self._disable_async)
39+
self._token_sources[key] = token_source
40+
return token_source.token()
41+
42+
43+
class DataPlaneEndpointTokenSource(oauth.Refreshable):
44+
"""
45+
A token source for a specific DataPlane endpoint.
46+
"""
47+
def __init__(self, cfg: config.Config, cpts: Callable[[], Token], auth_details: str, disable_async: bool):
48+
super().__init__(disable_async=disable_async)
49+
self._auth_details = auth_details
50+
self._cpts = cpts
51+
self._cfg = cfg
52+
53+
def refresh(self) -> Token:
54+
control_plane_token = self._cpts()
55+
headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE}
56+
params = parse.urlencode({
57+
"grant_type": JWT_BEARER_GRANT_TYPE,
58+
"authorization_details": self._auth_details,
59+
"assertion": control_plane_token.access_token
60+
})
61+
return oauth.retrieve_token(client_id=self._cfg.client_id,
62+
client_secret=self._cfg.client_secret,
63+
token_url=self._cfg.host + OIDC_TOKEN_PATH,
64+
params=params,
65+
headers=headers)
66+
667

768

869
@dataclass
@@ -16,6 +77,8 @@ class DataPlaneDetails:
1677
"""Token to query the DataPlane endpoint."""
1778

1879

80+
## Old implementation. #TODO: Remove after the new implementation is used
81+
1982
class DataPlaneService:
2083
"""Helper class to fetch and manage DataPlane details."""
2184
from .service.serving import DataPlaneInfo

tests/test_data_plane.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,84 @@
11
from datetime import datetime, timedelta
2+
from unittest.mock import patch
3+
from urllib import parse
24

5+
from databricks.sdk import data_plane, oauth
36
from databricks.sdk.data_plane import DataPlaneService
47
from databricks.sdk.oauth import Token
58
from databricks.sdk.service.serving import DataPlaneInfo
69

10+
11+
cp_token = Token(access_token="control plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
12+
dp_token = Token(access_token="data plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
13+
14+
def success_callable(token: oauth.Token):
15+
def success() -> oauth.Token:
16+
return token
17+
return success
18+
19+
def patch_retrieve_token(token: oauth.Token):
20+
return patch("databricks.sdk.oauth.retrieve_token", return_value=token)
21+
22+
def test_endpoint_token_source_get_token(config):
23+
config.client_id = "client_id"
24+
config.client_secret = "client_secret"
25+
token_source = data_plane.DataPlaneEndpointTokenSource(config, success_callable(cp_token), "authDetails", disable_async=True)
26+
27+
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
28+
token_source.token()
29+
30+
retrieve_token.assert_called_once()
31+
args, kwargs = retrieve_token.call_args
32+
33+
assert kwargs["client_id"] == config.client_id
34+
assert kwargs["client_secret"] == config.client_secret
35+
assert kwargs["token_url"] == config.host + "/oidc/v1/token"
36+
assert kwargs["params"] == parse.urlencode({"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "authorization_details": "authDetails", "assertion": cp_token.access_token})
37+
assert kwargs["headers"] == {"Content-Type": "application/x-www-form-urlencoded"}
38+
39+
def test_token_source_get_token_not_existing(config):
40+
config.client_id = "client_id"
41+
config.client_secret = "client_secret"
42+
token_source = data_plane.DataPlaneTokenSource(config, success_callable(cp_token), disable_async=True)
43+
44+
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
45+
result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")
46+
47+
retrieve_token.assert_called_once()
48+
assert result_token.access_token == dp_token.access_token
49+
assert "endpoint:authDetails" in token_source._token_sources
50+
51+
class MockEndpointTokenSource:
52+
53+
def __init__(self, token: oauth.Token):
54+
self._token = token
55+
56+
def token(self):
57+
return self._token
58+
59+
def test_token_source_get_token_existing(config):
60+
config.client_id = "client_id"
61+
config.client_secret = "client_secret"
62+
another_token = Token(access_token="another token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
63+
token_source = data_plane.DataPlaneTokenSource(config, success_callable(cp_token), disable_async=True)
64+
token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token)
65+
66+
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
67+
result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")
68+
69+
retrieve_token.assert_not_called()
70+
assert result_token.access_token == another_token.access_token
71+
72+
73+
74+
75+
76+
## These tests are for the old implementation. #TODO: Remove after the new implementation is used
77+
778
info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url")
879

980
token = Token(access_token="token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
1081

11-
1282
class MockRefresher:
1383

1484
def __init__(self, expected: str):

0 commit comments

Comments
 (0)