|
1 | 1 | from datetime import datetime, timedelta |
| 2 | +from unittest.mock import patch |
| 3 | +from urllib import parse |
2 | 4 |
|
| 5 | +from databricks.sdk import data_plane, oauth |
3 | 6 | from databricks.sdk.data_plane import DataPlaneService |
4 | 7 | from databricks.sdk.oauth import Token |
5 | 8 | from databricks.sdk.service.serving import DataPlaneInfo |
6 | 9 |
|
| 10 | +cp_token = Token(access_token="control plane token", |
| 11 | + token_type="type", |
| 12 | + expiry=datetime.now() + timedelta(hours=1)) |
| 13 | +dp_token = Token(access_token="data plane token", |
| 14 | + token_type="type", |
| 15 | + expiry=datetime.now() + timedelta(hours=1)) |
| 16 | + |
| 17 | + |
| 18 | +def success_callable(token: oauth.Token): |
| 19 | + |
| 20 | + def success() -> oauth.Token: |
| 21 | + return token |
| 22 | + |
| 23 | + return success |
| 24 | + |
| 25 | + |
| 26 | +def test_endpoint_token_source_get_token(config): |
| 27 | + config.client_id = "client_id" |
| 28 | + config.client_secret = "client_secret" |
| 29 | + config.oauth_token = success_callable(cp_token) |
| 30 | + token_source = data_plane.DataPlaneEndpointTokenSource(config, "authDetails", disable_async=True) |
| 31 | + |
| 32 | + with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: |
| 33 | + token_source.token() |
| 34 | + |
| 35 | + retrieve_token.assert_called_once() |
| 36 | + args, kwargs = retrieve_token.call_args |
| 37 | + |
| 38 | + assert kwargs["client_id"] == config.client_id |
| 39 | + assert kwargs["client_secret"] == config.client_secret |
| 40 | + assert kwargs["token_url"] == config.host + "/oidc/v1/token" |
| 41 | + assert kwargs["params"] == parse.urlencode({ |
| 42 | + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", |
| 43 | + "authorization_details": "authDetails", |
| 44 | + "assertion": cp_token.access_token |
| 45 | + }) |
| 46 | + assert kwargs["headers"] == {"Content-Type": "application/x-www-form-urlencoded"} |
| 47 | + |
| 48 | + |
| 49 | +def test_token_source_get_token_not_existing(config): |
| 50 | + config.client_id = "client_id" |
| 51 | + config.client_secret = "client_secret" |
| 52 | + config.oauth_token = success_callable(cp_token) |
| 53 | + token_source = data_plane.DataPlaneTokenSource(config, disable_async=True) |
| 54 | + |
| 55 | + with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: |
| 56 | + result_token = token_source.token(endpoint="endpoint", auth_details="authDetails") |
| 57 | + |
| 58 | + retrieve_token.assert_called_once() |
| 59 | + assert result_token.access_token == dp_token.access_token |
| 60 | + assert "endpoint:authDetails" in token_source._token_sources |
| 61 | + |
| 62 | + |
| 63 | +class MockEndpointTokenSource: |
| 64 | + |
| 65 | + def __init__(self, token: oauth.Token): |
| 66 | + self._token = token |
| 67 | + |
| 68 | + def token(self): |
| 69 | + return self._token |
| 70 | + |
| 71 | + |
| 72 | +def test_token_source_get_token_existing(config): |
| 73 | + config.client_id = "client_id" |
| 74 | + config.client_secret = "client_secret" |
| 75 | + config.oauth_token = success_callable(cp_token) |
| 76 | + another_token = Token(access_token="another token", |
| 77 | + token_type="type", |
| 78 | + expiry=datetime.now() + timedelta(hours=1)) |
| 79 | + token_source = data_plane.DataPlaneTokenSource(config, disable_async=True) |
| 80 | + token_source._token_sources["endpoint:authDetails"] = MockEndpointTokenSource(another_token) |
| 81 | + |
| 82 | + with patch("databricks.sdk.oauth.retrieve_token", return_value=dp_token) as retrieve_token: |
| 83 | + result_token = token_source.token(endpoint="endpoint", auth_details="authDetails") |
| 84 | + |
| 85 | + retrieve_token.assert_not_called() |
| 86 | + assert result_token.access_token == another_token.access_token |
| 87 | + |
| 88 | + |
| 89 | +## These tests are for the old implementation. #TODO: Remove after the new implementation is used |
| 90 | + |
7 | 91 | info = DataPlaneInfo(authorization_details="authDetails", endpoint_url="url") |
8 | 92 |
|
9 | 93 | token = Token(access_token="token", token_type="type", expiry=datetime.now() + timedelta(hours=1)) |
|
0 commit comments