1+ from __future__ import annotations
2+
13import threading
24from dataclasses import dataclass
3- from typing import Callable , List
5+ from typing import Callable , List , Optional
6+ from urllib import parse
47
8+ from databricks .sdk import oauth
59from databricks .sdk .oauth import Token
610
11+ URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
12+ JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
13+ OIDC_TOKEN_PATH = "/oidc/v1/token"
14+
15+
16+ class DataPlaneTokenSource :
17+ """
18+ EXPERIMENTAL Manages token sources for multiple DataPlane endpoints.
19+ """
20+
21+ # TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled.
22+ def __init__ (self , token_exchange_host : str , cpts : Callable [[], Token ], disable_async : Optional [bool ] = True ):
23+ self ._cpts = cpts
24+ self ._token_exchange_host = token_exchange_host
25+ self ._token_sources = {}
26+ self ._disable_async = disable_async
27+ self ._lock = threading .Lock ()
28+
29+ def token (self , endpoint , auth_details ):
30+ key = f"{ endpoint } :{ auth_details } "
31+
32+ # First, try to read without acquiring the lock to avoid contention.
33+ # Reads are atomic, so this is safe.
34+ token_source = self ._token_sources .get (key )
35+ if token_source :
36+ return token_source .token ()
37+
38+ # If token_source is not found, acquire the lock and check again.
39+ with self ._lock :
40+ # Another thread might have created it while we were waiting for the lock.
41+ token_source = self ._token_sources .get (key )
42+ if not token_source :
43+ token_source = DataPlaneEndpointTokenSource (
44+ self ._token_exchange_host , self ._cpts , auth_details , self ._disable_async
45+ )
46+ self ._token_sources [key ] = token_source
47+
48+ return token_source .token ()
49+
50+
51+ class DataPlaneEndpointTokenSource (oauth .Refreshable ):
52+ """
53+ EXPERIMENTAL A token source for a specific DataPlane endpoint.
54+ """
55+
56+ def __init__ (self , token_exchange_host : str , cpts : Callable [[], Token ], auth_details : str , disable_async : bool ):
57+ super ().__init__ (disable_async = disable_async )
58+ self ._auth_details = auth_details
59+ self ._cpts = cpts
60+ self ._token_exchange_host = token_exchange_host
61+
62+ def refresh (self ) -> Token :
63+ control_plane_token = self ._cpts ()
64+ headers = {"Content-Type" : URL_ENCODED_CONTENT_TYPE }
65+ params = parse .urlencode (
66+ {
67+ "grant_type" : JWT_BEARER_GRANT_TYPE ,
68+ "authorization_details" : self ._auth_details ,
69+ "assertion" : control_plane_token .access_token ,
70+ }
71+ )
72+ return oauth .retrieve_token (
73+ client_id = "" ,
74+ client_secret = "" ,
75+ token_url = self ._token_exchange_host + OIDC_TOKEN_PATH ,
76+ params = params ,
77+ headers = headers ,
78+ )
79+
780
881@dataclass
982class DataPlaneDetails :
@@ -17,6 +90,9 @@ class DataPlaneDetails:
1790 """Token to query the DataPlane endpoint."""
1891
1992
93+ ## Old implementation. #TODO: Remove after the new implementation is used
94+
95+
2096class DataPlaneService :
2197 """Helper class to fetch and manage DataPlane details."""
2298
0 commit comments