Skip to content

Commit 20eaca4

Browse files
committed
[Internal] Add DataPlane token source
1 parent e550ca1 commit 20eaca4

File tree

4 files changed

+186
-1
lines changed

4 files changed

+186
-1
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ lint:
2727
test:
2828
pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests
2929

30+
integration-dataplane:
31+
pytest -v --cov=databricks --cov-report html tests/integration/test_dataplane.py
32+
3033
integration:
3134
pytest -n auto -m 'integration and not benchmark' --reruns 2 --dist loadgroup --cov=databricks --cov-report html tests
3235

databricks/sdk/data_plane.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,69 @@
1+
from __future__ import annotations
2+
13
import threading
24
from dataclasses import dataclass
3-
from typing import Callable, List
5+
from typing import Callable, List, Optional
6+
from urllib import parse
47

8+
from databricks.sdk import config, oauth
59
from databricks.sdk.oauth import Token
610

11+
URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
12+
JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
13+
OIDC_TOKEN_PATH = "/oidc/v1/token"
14+
15+
16+
class DataPlaneTokenSource:
17+
"""
18+
EXPERIMENTAL Manages token sources for multiple DataPlane endpoints.
19+
"""
20+
21+
# TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled.
22+
def __init__(self, cfg: config.Config, disable_async: Optional[bool] = True):
23+
self._cfg = cfg
24+
self._token_sources = {}
25+
self._disable_async = disable_async
26+
27+
def token(self, endpoint: str, auth_details: str):
28+
"""
29+
Get a token for a specific DataPlane endpoint.
30+
:param endpoint: endpoint URL for which to get a token
31+
:param auth_details: authorization details used to generate the token
32+
:return: a token for the specified endpoint
33+
"""
34+
key = f"{endpoint}:{auth_details}"
35+
token_source = self._token_sources.get(key)
36+
if not token_source:
37+
token_source = DataPlaneEndpointTokenSource(self._cfg, auth_details, self._disable_async)
38+
self._token_sources[key] = token_source
39+
return token_source.token()
40+
41+
42+
class DataPlaneEndpointTokenSource(oauth.Refreshable):
43+
"""
44+
EXPERIMENTAL A token source for a specific DataPlane endpoint.
45+
"""
46+
47+
def __init__(self, cfg: config.Config, auth_details: str, disable_async: bool):
48+
super().__init__(disable_async=disable_async)
49+
self._auth_details = auth_details
50+
self._cpts = cfg.oauth_token
51+
self._cfg = cfg
52+
53+
def refresh(self) -> Token:
54+
control_plane_token = self._cpts()
55+
headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE}
56+
params = parse.urlencode({
57+
"grant_type": JWT_BEARER_GRANT_TYPE,
58+
"authorization_details": self._auth_details,
59+
"assertion": control_plane_token.access_token
60+
})
61+
return oauth.retrieve_token(client_id=self._cfg.client_id,
62+
client_secret=self._cfg.client_secret,
63+
token_url=self._cfg.host + OIDC_TOKEN_PATH,
64+
params=params,
65+
headers=headers)
66+
767

868
@dataclass
969
class DataPlaneDetails:
@@ -16,6 +76,9 @@ class DataPlaneDetails:
1676
"""Token to query the DataPlane endpoint."""
1777

1878

79+
## Old implementation. #TODO: Remove after the new implementation is used
80+
81+
1982
class DataPlaneService:
2083
"""Helper class to fetch and manage DataPlane details."""
2184
from .service.serving import DataPlaneInfo
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import base64
2+
import io
3+
import json
4+
import re
5+
import shutil
6+
import subprocess
7+
import sys
8+
import typing
9+
import urllib.parse
10+
from functools import partial
11+
from pathlib import Path
12+
13+
import pytest
14+
15+
from databricks.sdk.data_plane import DataPlaneTokenSource
16+
from databricks.sdk.service.compute import (ClusterSpec, DataSecurityMode,
17+
Library, ResultType, SparkVersion)
18+
from databricks.sdk.service.jobs import NotebookTask, Task, ViewType
19+
from databricks.sdk.service.workspace import ImportFormat
20+
21+
22+
def test_data_plane_token_source(ucws, env_or_skip):
23+
endpoint = env_or_skip("SERVING_ENDPOINT_NAME")
24+
serving_endpoint = ucws.serving_endpoints.get(endpoint)
25+
assert serving_endpoint.data_plane_info is not None
26+
assert serving_endpoint.data_plane_info.query_info is not None
27+
28+
info = serving_endpoint.data_plane_info.query_info
29+
30+
ts = DataPlaneTokenSource(ucws.config)
31+
dp_token = ts.token(info.endpoint_url, info.authorization_details)
32+
33+
assert dp_token.valid
34+
35+

tests/test_data_plane.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,93 @@
11
from datetime import datetime, timedelta
2+
from unittest.mock import patch
3+
from urllib import parse
24

5+
from databricks.sdk import data_plane, oauth
36
from databricks.sdk.data_plane import DataPlaneService
47
from databricks.sdk.oauth import Token
58
from databricks.sdk.service.serving import DataPlaneInfo
69

10+
cp_token = Token(access_token="control plane token",
11+
token_type="type",
12+
expiry=datetime.now() + timedelta(hours=1))
13+
dp_token = Token(access_token="data plane token",
14+
token_type="type",
15+
expiry=datetime.now() + timedelta(hours=1))
16+
17+
18+
def success_callable(token: oauth.Token):
19+
20+
def success() -> oauth.Token:
21+
return token
22+
23+
return success
24+
25+
26+
def test_endpoint_token_source_get_token(config):
27+
config.client_id = "client_id"
28+
config.client_secret = "client_secret"
29+
config.oauth_token = success_callable(cp_token)
30+
token_source = data_plane.DataPlaneEndpointTokenSource(config, "authDetails", disable_async=True)
31+
32+
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
33+
token_source.token()
34+
35+
retrieve_token.assert_called_once()
36+
args, kwargs = retrieve_token.call_args
37+
38+
assert kwargs["client_id"] == config.client_id
39+
assert kwargs["client_secret"] == config.client_secret
40+
assert kwargs["token_url"] == config.host + "/oidc/v1/token"
41+
assert kwargs["params"] == parse.urlencode({
42+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
43+
"authorization_details": "authDetails",
44+
"assertion": cp_token.access_token
45+
})
46+
assert kwargs["headers"] == {"Content-Type": "application/x-www-form-urlencoded"}
47+
48+
49+
def test_token_source_get_token_not_existing(config):
50+
config.client_id = "client_id"
51+
config.client_secret = "client_secret"
52+
config.oauth_token = success_callable(cp_token)
53+
token_source = data_plane.DataPlaneTokenSource(config, disable_async=True)
54+
55+
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
56+
result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")
57+
58+
retrieve_token.assert_called_once()
59+
assert result_token.access_token == dp_token.access_token
60+
assert "endpoint:authDetails" in token_source._token_sources
61+
62+
63+
class MockEndpointTokenSource:
64+
65+
def __init__(self, token: oauth.Token):
66+
self._token = token
67+
68+
def token(self):
69+
return self._token
70+
71+
72+
def test_token_source_get_token_existing(config):
73+
config.client_id = "client_id"
74+
config.client_secret = "client_secret"
75+
config.oauth_token = success_callable(cp_token)
76+
another_token = Token(access_token="another token",
77+
token_type="type",
78+
expiry=datetime.now() + timedelta(hours=1))
79+
token_source = data_plane.DataPlaneTokenSource(config, disable_async=True)
80+
token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token)
81+
82+
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
83+
result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")
84+
85+
retrieve_token.assert_not_called()
86+
assert result_token.access_token == another_token.access_token
87+
88+
89+
## These tests are for the old implementation. #TODO: Remove after the new implementation is used
90+
791
info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url")
892

993
token = Token(access_token="token", token_type="type", expiry=datetime.now() + timedelta(hours=1))

0 commit comments

Comments
 (0)