Skip to content

Commit 5530671

Browse files
committed
black fmt
1 parent 86243af commit 5530671

File tree

3 files changed

+34
-40
lines changed

3 files changed

+34
-40
lines changed

databricks/sdk/data_plane.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,14 @@ class DataPlaneTokenSource:
1919
"""
2020

2121
# TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled.
22-
def __init__(self,
23-
token_exchange_host: str,
24-
cpts: Callable[[], Token],
25-
disable_async: Optional[bool] = True):
22+
def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], disable_async: Optional[bool] = True):
2623
self._cpts = cpts
2724
self._token_exchange_host = token_exchange_host
2825
self._token_sources = {}
2926
self._disable_async = disable_async
3027
self._lock = threading.Lock()
3128

32-
def get_token(self, endpoint, auth_details):
29+
def token(self, endpoint, auth_details):
3330
key = f"{endpoint}:{auth_details}"
3431

3532
# First, try to read without acquiring the lock to avoid contention.
@@ -43,8 +40,9 @@ def get_token(self, endpoint, auth_details):
4340
# Another thread might have created it while we were waiting for the lock.
4441
token_source = self._token_sources.get(key)
4542
if not token_source:
46-
token_source = DataPlaneEndpointTokenSource(self._token_exchange_host, self._cpts,
47-
auth_details, self._disable_async)
43+
token_source = DataPlaneEndpointTokenSource(
44+
self._token_exchange_host, self._cpts, auth_details, self._disable_async
45+
)
4846
self._token_sources[key] = token_source
4947

5048
return token_source.token()
@@ -55,8 +53,7 @@ class DataPlaneEndpointTokenSource(oauth.Refreshable):
5553
EXPERIMENTAL A token source for a specific DataPlane endpoint.
5654
"""
5755

58-
def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_details: str,
59-
disable_async: bool):
56+
def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_details: str, disable_async: bool):
6057
super().__init__(disable_async=disable_async)
6158
self._auth_details = auth_details
6259
self._cpts = cpts
@@ -65,16 +62,20 @@ def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_det
6562
def refresh(self) -> Token:
6663
control_plane_token = self._cpts()
6764
headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE}
68-
params = parse.urlencode({
69-
"grant_type": JWT_BEARER_GRANT_TYPE,
70-
"authorization_details": self._auth_details,
71-
"assertion": control_plane_token.access_token
72-
})
73-
return oauth.retrieve_token(client_id="",
74-
client_secret="",
75-
token_url=self._token_exchange_host + OIDC_TOKEN_PATH,
76-
params=params,
77-
headers=headers)
65+
params = parse.urlencode(
66+
{
67+
"grant_type": JWT_BEARER_GRANT_TYPE,
68+
"authorization_details": self._auth_details,
69+
"assertion": control_plane_token.access_token,
70+
}
71+
)
72+
return oauth.retrieve_token(
73+
client_id="",
74+
client_secret="",
75+
token_url=self._token_exchange_host + OIDC_TOKEN_PATH,
76+
params=params,
77+
headers=headers,
78+
)
7879

7980

8081
@dataclass
File renamed without changes.

tests/test_data_plane.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,8 @@
77
from databricks.sdk.oauth import Token
88
from databricks.sdk.service.serving import DataPlaneInfo
99

10-
cp_token = Token(access_token="control plane token",
11-
token_type="type",
12-
expiry=datetime.now() + timedelta(hours=1))
13-
dp_token = Token(access_token="data plane token",
14-
token_type="type",
15-
expiry=datetime.now() + timedelta(hours=1))
10+
cp_token = Token(access_token="control plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
11+
dp_token = Token(access_token="data plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
1612

1713

1814
def success_callable(token: oauth.Token):
@@ -24,10 +20,9 @@ def success() -> oauth.Token:
2420

2521

2622
def test_endpoint_token_source_get_token(config):
27-
token_source = data_plane.DataPlaneEndpointTokenSource(config.host,
28-
success_callable(cp_token),
29-
"authDetails",
30-
disable_async=True)
23+
token_source = data_plane.DataPlaneEndpointTokenSource(
24+
config.host, success_callable(cp_token), "authDetails", disable_async=True
25+
)
3126

3227
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
3328
token_source.token()
@@ -36,18 +31,18 @@ def test_endpoint_token_source_get_token(config):
3631
args, kwargs = retrieve_token.call_args
3732

3833
assert kwargs["token_url"] == config.host + "/oidc/v1/token"
39-
assert kwargs["params"] == parse.urlencode({
40-
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
41-
"authorization_details": "authDetails",
42-
"assertion": cp_token.access_token
43-
})
34+
assert kwargs["params"] == parse.urlencode(
35+
{
36+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
37+
"authorization_details": "authDetails",
38+
"assertion": cp_token.access_token,
39+
}
40+
)
4441
assert kwargs["headers"] == {"Content-Type": "application/x-www-form-urlencoded"}
4542

4643

4744
def test_token_source_get_token_not_existing(config):
48-
token_source = data_plane.DataPlaneTokenSource(config.host,
49-
success_callable(cp_token),
50-
disable_async=True)
45+
token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(cp_token), disable_async=True)
5146

5247
with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token:
5348
result_token = token_source.token(endpoint="endpoint", auth_details="authDetails")
@@ -67,9 +62,7 @@ def token(self):
6762

6863

6964
def test_token_source_get_token_existing(config):
70-
another_token = Token(access_token="another token",
71-
token_type="type",
72-
expiry=datetime.now() + timedelta(hours=1))
65+
another_token = Token(access_token="another token", token_type="type", expiry=datetime.now() + timedelta(hours=1))
7366
token_source = data_plane.DataPlaneTokenSource(config.host, success_callable(token), disable_async=True)
7467
token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token)
7568

0 commit comments

Comments
 (0)