|
1 | 1 | from datetime import datetime, timedelta |
| 2 | +from unittest.mock import patch |
| 3 | +from urllib import parse |
2 | 4 |
|
| 5 | +from databricks.sdk import data_plane, oauth |
3 | 6 | from databricks.sdk.data_plane import DataPlaneService |
4 | 7 | from databricks.sdk.oauth import Token |
5 | 8 | from databricks.sdk.service.serving import DataPlaneInfo |
6 | 9 |
|
| 10 | + |
| 11 | +cp_token = Token(access_token="control plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) |
| 12 | +dp_token = Token(access_token="data plane token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) |
| 13 | + |
| 14 | +def success_callable(token: oauth.Token): |
| 15 | + def success() -> oauth.Token: |
| 16 | + return token |
| 17 | + return success |
| 18 | + |
| 19 | +def patch_retrieve_token(token: oauth.Token): |
| 20 | + return patch("databricks.sdk.oauth.retrieve_token", return_value=token) |
| 21 | + |
| 22 | +def test_endpoint_token_source_get_token(config): |
| 23 | + config.client_id = "client_id" |
| 24 | + config.client_secret = "client_secret" |
| 25 | + token_source = data_plane.DataPlaneEndpointTokenSource(config, success_callable(cp_token), "authDetails", disable_async=True) |
| 26 | + |
| 27 | + with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: |
| 28 | + token_source.token() |
| 29 | + |
| 30 | + retrieve_token.assert_called_once() |
| 31 | + args, kwargs = retrieve_token.call_args |
| 32 | + |
| 33 | + assert kwargs["client_id"] == config.client_id |
| 34 | + assert kwargs["client_secret"] == config.client_secret |
| 35 | + assert kwargs["token_url"] == config.host + "/oidc/v1/token" |
| 36 | + assert kwargs["params"] == parse.urlencode({"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "authorization_details": "authDetails", "assertion": cp_token.access_token}) |
| 37 | + assert kwargs["headers"] == {"Content-Type": "application/x-www-form-urlencoded"} |
| 38 | + |
| 39 | +def test_token_source_get_token_not_existing(config): |
| 40 | + config.client_id = "client_id" |
| 41 | + config.client_secret = "client_secret" |
| 42 | + token_source = data_plane.DataPlaneTokenSource(config, success_callable(cp_token), disable_async=True) |
| 43 | + |
| 44 | + with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: |
| 45 | + result_token = token_source.token(endpoint="endpoint", auth_details="authDetails") |
| 46 | + |
| 47 | + retrieve_token.assert_called_once() |
| 48 | + assert result_token.access_token == dp_token.access_token |
| 49 | + assert "endpoint:authDetails" in token_source._token_sources |
| 50 | + |
| 51 | +class MockEndpointTokenSource: |
| 52 | + |
| 53 | + def __init__(self, token: oauth.Token): |
| 54 | + self._token = token |
| 55 | + |
| 56 | + def token(self): |
| 57 | + return self._token |
| 58 | + |
| 59 | +def test_token_source_get_token_existing(config): |
| 60 | + config.client_id = "client_id" |
| 61 | + config.client_secret = "client_secret" |
| 62 | + another_token = Token(access_token="another token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) |
| 63 | + token_source = data_plane.DataPlaneTokenSource(config, success_callable(cp_token), disable_async=True) |
| 64 | + token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token) |
| 65 | + |
| 66 | + with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: |
| 67 | + result_token = token_source.token(endpoint="endpoint", auth_details="authDetails") |
| 68 | + |
| 69 | + retrieve_token.assert_not_called() |
| 70 | + assert result_token.access_token == another_token.access_token |
| 71 | + |
| 72 | + |
| 73 | + |
| 74 | + |
| 75 | + |
| 76 | +## These tests are for the old implementation. #TODO: Remove after the new implementation is used |
| 77 | + |
7 | 78 | info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url") |
8 | 79 |
|
9 | 80 | token = Token(access_token="token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) |
10 | 81 |
|
11 | | - |
12 | 82 | class MockRefresher: |
13 | 83 |
|
14 | 84 | def __init__(self, expected: str): |
|
0 commit comments