|
5 | 5 |
|
6 | 6 | from databricks.sdk.core import Config |
7 | 7 | from databricks.sdk.credentials_provider import ModelServingUserCredentials |
| 8 | +from greenlet import greenlet, getcurrent |
8 | 9 |
|
9 | 10 | from .conftest import raises |
10 | 11 |
|
@@ -217,7 +218,41 @@ def authenticate(): |
217 | 218 | thread.start() |
218 | 219 | thread.join() |
219 | 220 | assert successful_authentication_event.is_set() |
| 221 | + del current_thread.__dict__["invokers_token"] # Clean up invokers token |
220 | 222 |
|
| 223 | +def test_agent_user_credentials_via_greenlet(monkeypatch, mocker): |
| 224 | + # Guarantee that the tests defaults to env variables rather than config file. |
| 225 | + # |
| 226 | + # TODO: this is hacky and we should find a better way to tell the config |
| 227 | + # that it should not read from the config file. |
| 228 | + monkeypatch.setenv("DATABRICKS_CONFIG_FILE", "x") |
| 229 | + |
| 230 | + monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true") |
| 231 | + monkeypatch.setenv("DB_MODEL_SERVING_HOST_URL", "x") |
| 232 | + monkeypatch.setattr( |
| 233 | + "databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", |
| 234 | + "tests/testdata/model-serving-test-token", |
| 235 | + ) |
| 236 | + |
| 237 | + invokers_token_val = "databricks_invokers_token" |
| 238 | + greenlet_local = getcurrent() |
| 239 | + setattr(greenlet_local, "invokers_token", invokers_token_val) |
| 240 | + |
| 241 | + cfg = Config(credentials_strategy=ModelServingUserCredentials()) |
| 242 | + assert cfg.auth_type == "model_serving_user_credentials" |
| 243 | + |
| 244 | + headers = cfg.authenticate() |
| 245 | + |
| 246 | + assert cfg.host == "x" |
| 247 | + assert headers.get("Authorization") == f"Bearer {invokers_token_val}" |
| 248 | + |
| 249 | + # Test updates of invokers token |
| 250 | + invokers_token_val = "databricks_invokers_token_v2" |
| 251 | + setattr(greenlet_local, "invokers_token", invokers_token_val) |
| 252 | + |
| 253 | + headers = cfg.authenticate() |
| 254 | + assert cfg.host == "x" |
| 255 | + assert headers.get("Authorization") == f"Bearer {invokers_token_val}" |
221 | 256 |
|
222 | 257 | # If this credential strategy is being used in a non model serving environments then use default credential strategy instead |
223 | 258 | def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch): |
|
0 commit comments