1+ from __future__ import annotations
2+
13import threading
24from dataclasses import dataclass
3- from typing import Callable , List
5+ from typing import Callable , List , Optional
6+ from urllib import parse
47
8+ from databricks .sdk import oauth
59from databricks .sdk .oauth import Token
610
11+ URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
12+ JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
13+ OIDC_TOKEN_PATH = "/oidc/v1/token"
14+
15+
16+ class DataPlaneTokenSource :
17+ """
18+ EXPERIMENTAL Manages token sources for multiple DataPlane endpoints.
19+ """
20+
21+ # TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled.
22+ def __init__ (self ,
23+ token_exchange_host : str ,
24+ cpts : Callable [[], Token ],
25+ disable_async : Optional [bool ] = True ):
26+ self ._cpts = cpts
27+ self ._token_exchange_host = token_exchange_host
28+ self ._token_sources = {}
29+ self ._disable_async = disable_async
30+ self ._lock = threading .Lock ()
31+
32+ def get_token (self , endpoint , auth_details ):
33+ key = f"{ endpoint } :{ auth_details } "
34+
35+ # First, try to read without acquiring the lock to avoid contention.
36+ # Reads are atomic, so this is safe.
37+ token_source = self ._token_sources .get (key )
38+ if token_source :
39+ return token_source .token ()
40+
41+ # If token_source is not found, acquire the lock and check again.
42+ with self ._lock :
43+ # Another thread might have created it while we were waiting for the lock.
44+ token_source = self ._token_sources .get (key )
45+ if not token_source :
46+ token_source = DataPlaneEndpointTokenSource (self ._token_exchange_host , self ._cpts ,
47+ auth_details , self ._disable_async )
48+ self ._token_sources [key ] = token_source
49+
50+ return token_source .token ()
51+
52+
53+ class DataPlaneEndpointTokenSource (oauth .Refreshable ):
54+ """
55+ EXPERIMENTAL A token source for a specific DataPlane endpoint.
56+ """
57+
58+ def __init__ (self , token_exchange_host : str , cpts : Callable [[], Token ], auth_details : str ,
59+ disable_async : bool ):
60+ super ().__init__ (disable_async = disable_async )
61+ self ._auth_details = auth_details
62+ self ._cpts = cpts
63+ self ._token_exchange_host = token_exchange_host
64+
65+ def refresh (self ) -> Token :
66+ control_plane_token = self ._cpts ()
67+ headers = {"Content-Type" : URL_ENCODED_CONTENT_TYPE }
68+ params = parse .urlencode ({
69+ "grant_type" : JWT_BEARER_GRANT_TYPE ,
70+ "authorization_details" : self ._auth_details ,
71+ "assertion" : control_plane_token .access_token
72+ })
73+ return oauth .retrieve_token (client_id = "" ,
74+ client_secret = "" ,
75+ token_url = self ._token_exchange_host + OIDC_TOKEN_PATH ,
76+ params = params ,
77+ headers = headers )
78+
779
880@dataclass
981class DataPlaneDetails :
@@ -16,6 +88,9 @@ class DataPlaneDetails:
1688 """Token to query the DataPlane endpoint."""
1789
1890
91+ ## Old implementation. #TODO: Remove after the new implementation is used
92+
93+
1994class DataPlaneService :
2095 """Helper class to fetch and manage DataPlane details."""
2196 from .service .serving import DataPlaneInfo
0 commit comments