Skip to content

Commit 4797295

Browse files
committed
Add support for agent user and embedded credential strategies
Signed-off-by: aravind-segu <[email protected]>
1 parent 5339396 commit 4797295

File tree

2 files changed

+162
-36
lines changed

2 files changed

+162
-36
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 74 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import platform
1010
import subprocess
1111
import sys
12+
import threading
1213
import time
1314
from datetime import datetime
1415
from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -313,12 +314,11 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
313314
# detect Azure AD Tenant ID if it's not specified directly
314315
token_endpoint = cfg.oidc_endpoints.token_endpoint
315316
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0]
316-
inner = ClientCredentials(
317-
client_id=cfg.azure_client_id,
318-
client_secret="", # we have no (rotatable) secrets in OIDC flow
319-
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
320-
endpoint_params=params,
321-
use_params=True)
317+
inner = ClientCredentials(client_id=cfg.azure_client_id,
318+
client_secret="", # we have no (rotatable) secrets in OIDC flow
319+
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
320+
endpoint_params=params,
321+
use_params=True)
322322

323323
def refreshed_headers() -> Dict[str, str]:
324324
token = inner.token()
@@ -717,14 +717,18 @@ def inner() -> Dict[str, str]:
717717
# This code is derived from MLflow's DatabricksModelServingConfigProvider
718718
# https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
719719
class ModelServingAuthProvider():
720+
USER_CREDENTIALS = "user_credentials"
721+
EMBEDDED_CREDENTIALS = "embedded_credentials"
722+
720723
_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
721724

722-
def __init__(self):
725+
def __init__(self, credential_type):
723726
self.expiry_time = -1
724727
self.current_token = None
725728
self.refresh_duration = 300 # 300 Seconds
729+
self.credential_type = credential_type
726730

727-
def should_fetch_model_serving_environment_oauth(self) -> bool:
731+
def should_fetch_model_serving_environment_oauth() -> bool:
728732
"""
729733
Check whether this is the model serving environment
730734
Additionally check if the oauth token file path exists
@@ -733,15 +737,15 @@ def should_fetch_model_serving_environment_oauth(self) -> bool:
733737
is_in_model_serving_env = (os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
734738
or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV") or "false")
735739
return (is_in_model_serving_env == "true"
736-
and os.path.isfile(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))
740+
and os.path.isfile(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))
737741

738742
def get_model_dependency_oauth_token(self, should_retry=True) -> str:
739743
# Use Cached value if it is valid
740744
if self.current_token is not None and self.expiry_time > time.time():
741745
return self.current_token
742746

743747
try:
744-
with open(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
748+
with open(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
745749
oauth_dict = json.load(f)
746750
self.current_token = oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]
747751
self.expiry_time = time.time() + self.refresh_duration
@@ -758,25 +762,38 @@ def get_model_dependency_oauth_token(self, should_retry=True) -> str:
758762
) from e
759763
return self.current_token
760764

765+
def get_invokers_token(self):
766+
current_thread = threading.current_thread()
767+
thread_data = current_thread.__dict__
768+
invokers_token = None
769+
if "invokers_token" in thread_data:
770+
invokers_token = thread_data["invokers_token"]
771+
772+
if invokers_token is None:
773+
raise RuntimeError("Unable to read Invokers Token in Databricks Model Serving")
774+
775+
return invokers_token
776+
761777
def get_databricks_host_token(self) -> Optional[Tuple[str, str]]:
762-
if not self.should_fetch_model_serving_environment_oauth():
778+
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
763779
return None
764780

765781
# read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
766782
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
767783
"DB_MODEL_SERVING_HOST_URL")
768-
token = self.get_model_dependency_oauth_token()
784+
token = self.get_model_dependency_oauth_token(
785+
) if self.credential_type == ModelServingAuthProvider.EMBEDDED_CREDENTIALS else self.get_invokers_token(
786+
)
769787

770788
return (host, token)
771789

772790

773-
@credentials_strategy('model-serving', [])
774-
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
791+
def model_serving_auth_func(cfg: 'Config', credential_type) -> Optional[CredentialsProvider]:
775792
try:
776-
model_serving_auth_provider = ModelServingAuthProvider()
777-
if not model_serving_auth_provider.should_fetch_model_serving_environment_oauth():
793+
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
778794
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
779795
return None
796+
model_serving_auth_provider = ModelServingAuthProvider(credential_type)
780797
host, token = model_serving_auth_provider.get_databricks_host_token()
781798
if token is None:
782799
raise ValueError(
@@ -787,7 +804,6 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
787804
except Exception as e:
788805
logger.warning("Unable to get auth from Databricks Model Serving Environment", exc_info=e)
789806
return None
790-
791807
logger.info("Using Databricks Model Serving Authentication")
792808

793809
def inner() -> Dict[str, str]:
@@ -798,6 +814,11 @@ def inner() -> Dict[str, str]:
798814
return inner
799815

800816

817+
@credentials_strategy('model-serving', [])
818+
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
819+
return model_serving_auth_func(cfg, ModelServingAuthProvider.EMBEDDED_CREDENTIALS)
820+
821+
801822
class DefaultCredentials:
802823
""" Select the first applicable credential provider from the chain """
803824

@@ -840,3 +861,39 @@ def __call__(self, cfg: 'Config') -> CredentialsProvider:
840861
raise ValueError(
841862
f'cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method.'
842863
)
864+
865+
866+
class AgentCredentials(CredentialsStrategy):
867+
868+
def __init__(self, credential_type):
869+
self.credential_type = credential_type
870+
self.default_credentials = DefaultCredentials()
871+
872+
def auth_type(self):
873+
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
874+
return "agent_" + self.credential_type
875+
else:
876+
return self.default_credentials.auth_type()
877+
878+
def __call__(self, cfg: 'Config') -> CredentialsProvider:
879+
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
880+
header_factory = model_serving_auth_func(cfg, self.credential_type)
881+
if not header_factory:
882+
raise ValueError(
883+
f"Unable to authenticate using {self.credential_type} in Databricks Model Serving Environment"
884+
)
885+
return header_factory
886+
else:
887+
return self.default_credentials(cfg)
888+
889+
890+
class AgentUserCredentials(AgentCredentials):
891+
892+
def __init__(self):
893+
super().__init__(ModelServingAuthProvider.USER_CREDENTIALS)
894+
895+
896+
class AgentEmbeddedCredentials(AgentCredentials):
897+
898+
def __init__(self):
899+
super().__init__(ModelServingAuthProvider.EMBEDDED_CREDENTIALS)

tests/test_model_serving_auth.py

Lines changed: 88 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import threading
12
import time
23

34
import pytest
45

56
from databricks.sdk.core import Config
7+
from databricks.sdk.credentials_provider import (AgentEmbeddedCredentials,
8+
AgentUserCredentials)
69

710
from .conftest import raises
811

@@ -24,7 +27,9 @@
2427
([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true'),
2528
('DATABRICKS_MODEL_SERVING_HOST_URL', 'x')
2629
], ['DB_MODEL_SERVING_HOST_URL'], "tests/testdata/model-serving-test-token"), ])
27-
def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeypatch, mocker):
30+
@pytest.mark.parametrize("use_credential_strategy", [True, False])
31+
def test_model_serving_auth(env_values, del_env_values, oauth_file_name, use_credential_strategy, monkeypatch,
32+
mocker):
2833
## In mlflow we check for these two environment variables to return the correct config
2934
for (env_name, env_value) in env_values:
3035
monkeypatch.setenv(env_name, env_value)
@@ -37,26 +42,25 @@ def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeyp
3742
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
3843
oauth_file_name)
3944
mocker.patch('databricks.sdk.config.Config._known_file_config_loader')
40-
41-
cfg = Config()
42-
43-
assert cfg.auth_type == 'model-serving'
45+
if use_credential_strategy:
46+
cfg = Config(credentials_strategy=AgentEmbeddedCredentials())
47+
assert cfg.auth_type == 'agent_embedded_credentials'
48+
else:
49+
cfg = Config()
50+
assert cfg.auth_type == 'model-serving'
4451
headers = cfg.authenticate()
4552
assert (cfg.host == 'x')
4653
# Token defined in the test file
4754
assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token'
4855

4956

50-
@pytest.mark.parametrize(
51-
"env_values, oauth_file_name",
52-
[
53-
([], "invalid_file_name"), # Not in Model Serving and Invalid File Name
54-
([('IS_IN_DB_MODEL_SERVING_ENV', 'true')
55-
], "invalid_file_name"), # In Model Serving and Invalid File Name
56-
([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')
57-
], "invalid_file_name"), # In Model Serving and Invalid File Name
58-
([], "tests/testdata/model-serving-test-token") # Not in Model Serving and Valid File Name
59-
])
57+
@pytest.mark.parametrize("env_values, oauth_file_name", [
58+
([], "invalid_file_name"), # Not in Model Serving and Invalid File Name
59+
([('IS_IN_DB_MODEL_SERVING_ENV', 'true')], "invalid_file_name"), # In Model Serving and Invalid File Name
60+
([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')
61+
], "invalid_file_name"), # In Model Serving and Invalid File Name
62+
([], "tests/testdata/model-serving-test-token") # Not in Model Serving and Valid File Name
63+
])
6064
@raises(default_auth_base_error_message)
6165
def test_model_serving_auth_errors(env_values, oauth_file_name, monkeypatch):
6266
# Guarantee that the test defaults to env variables rather than the config file.
@@ -74,7 +78,8 @@ def test_model_serving_auth_errors(env_values, oauth_file_name, monkeypatch):
7478
Config()
7579

7680

77-
def test_model_serving_auth_refresh(monkeypatch, mocker):
81+
@pytest.mark.parametrize("use_credential_strategy", [True, False])
82+
def test_model_serving_auth_refresh(use_credential_strategy, monkeypatch, mocker):
7883
## In mlflow we check for these two environment variables to return the correct config
7984
monkeypatch.setenv('IS_IN_DB_MODEL_SERVING_ENV', 'true')
8085
monkeypatch.setenv('DB_MODEL_SERVING_HOST_URL', 'x')
@@ -85,15 +90,18 @@ def test_model_serving_auth_refresh(monkeypatch, mocker):
8590
"tests/testdata/model-serving-test-token")
8691
mocker.patch('databricks.sdk.config.Config._known_file_config_loader')
8792

88-
cfg = Config()
89-
assert cfg.auth_type == 'model-serving'
93+
if use_credential_strategy:
94+
cfg = Config(credentials_strategy=AgentEmbeddedCredentials())
95+
assert cfg.auth_type == 'agent_embedded_credentials'
96+
else:
97+
cfg = Config()
98+
assert cfg.auth_type == 'model-serving'
9099

91100
current_time = time.time()
92101
headers = cfg.authenticate()
93102
assert (cfg.host == 'x')
94103
assert headers.get(
95104
"Authorization") == 'Bearer databricks_sdk_unit_test_token' # Token defined in the test file
96-
97105
# Simulate refreshing the token by patching to a new file
98106
monkeypatch.setattr(
99107
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
@@ -113,3 +121,64 @@ def test_model_serving_auth_refresh(monkeypatch, mocker):
113121
assert (cfg.host == 'x')
114122
# Read V2 now
115123
assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token_v2'
124+
125+
126+
def test_agent_user_credentials(monkeypatch, mocker):
127+
monkeypatch.setenv('IS_IN_DB_MODEL_SERVING_ENV', 'true')
128+
monkeypatch.setenv('DB_MODEL_SERVING_HOST_URL', 'x')
129+
monkeypatch.setattr(
130+
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
131+
"tests/testdata/model-serving-test-token")
132+
133+
invokers_token_val = "databricks_invokers_token"
134+
current_thread = threading.current_thread()
135+
thread_data = current_thread.__dict__
136+
thread_data["invokers_token"] = invokers_token_val
137+
138+
cfg = Config(credentials_strategy=AgentUserCredentials())
139+
assert cfg.auth_type == 'agent_user_credentials'
140+
141+
headers = cfg.authenticate()
142+
143+
assert (cfg.host == 'x')
144+
assert headers.get("Authorization") == f'Bearer {invokers_token_val}'
145+
146+
# Test updates of invokers token
147+
invokers_token_val = "databricks_invokers_token_v2"
148+
current_thread = threading.current_thread()
149+
thread_data = current_thread.__dict__
150+
thread_data["invokers_token"] = invokers_token_val
151+
152+
headers = cfg.authenticate()
153+
assert (cfg.host == 'x')
154+
assert headers.get("Authorization") == f'Bearer {invokers_token_val}'
155+
156+
157+
# If this credential strategy is used outside a model serving environment, fall back to the default credential strategy instead
158+
def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch):
159+
160+
monkeypatch.setenv('DATABRICKS_HOST', 'x')
161+
monkeypatch.setenv('DATABRICKS_TOKEN', 'token')
162+
163+
cfg = Config(credentials_strategy=AgentUserCredentials())
164+
assert cfg.auth_type == 'pat' # Auth type is PAT as it is no longer in a model serving environment
165+
166+
headers = cfg.authenticate()
167+
168+
assert (cfg.host == 'https://x')
169+
assert headers.get("Authorization") == f'Bearer token'
170+
171+
172+
# If this credential strategy is used outside a model serving environment, fall back to the default credential strategy instead
173+
def test_agent_embedded_credentials_in_non_model_serving_environments(monkeypatch):
174+
175+
monkeypatch.setenv('DATABRICKS_HOST', 'x')
176+
monkeypatch.setenv('DATABRICKS_TOKEN', 'token')
177+
178+
cfg = Config(credentials_strategy=AgentEmbeddedCredentials())
179+
assert cfg.auth_type == 'pat' # Auth type is PAT as it is no longer in a model serving environment
180+
181+
headers = cfg.authenticate()
182+
183+
assert (cfg.host == 'https://x')
184+
assert headers.get("Authorization") == f'Bearer token'

0 commit comments

Comments
 (0)