Skip to content

Commit 41f5f4b

Browse files
authored
[Feature] Introduce new Credential Strategies for Agents (#882)
## What changes are proposed in this pull request? This PR introduces two new credential strategies for Agents (AgentEmbeddedCredentials, AgentUserCredentials). Agents currently use the databricks.sdk to interact with Databricks resources. However, the authentication method for these resources is somewhat unusual: the authentication token is stored in a credential file on the Kubernetes container. For that reason we previously added the Model Serving credential strategy to the DefaultCredentials list so that it reads this file. Now we want to introduce a new authentication method in which the user's token is instead stored in a thread-local variable. Agent users will initialize clients as follows: ``` from databricks.sdk.credentials_provider import ModelServingUserCredentials invokers_client = WorkspaceClient(credentials_strategy = ModelServingUserCredentials()) definers_client = WorkspaceClient() ``` Users can then use invokers_client to interact with resources using the invoker's token, or definers_client to interact with resources using the existing authentication method. Additionally, because users will test their code locally in Databricks Notebooks with these clients, the strategy must fall back to the DefaultCredentials strategies when the code is not running in a model serving environment. More details: https://docs.google.com/document/d/14qLVjyxIAk581w287TWElstIeh8-DR30ab9Z6B_Vydg/edit?usp=sharing ## How is this tested? Added unit tests --------- Signed-off-by: aravind-segu <[email protected]>
1 parent 3c391a0 commit 41f5f4b

File tree

2 files changed

+119
-18
lines changed

2 files changed

+119
-18
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import platform
1010
import subprocess
1111
import sys
12+
import threading
1213
import time
1314
from datetime import datetime
1415
from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -723,14 +724,17 @@ def inner() -> Dict[str, str]:
723724
# This Code is derived from Mlflow DatabricksModelServingConfigProvider
724725
# https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
725726
class ModelServingAuthProvider():
727+
USER_CREDENTIALS = "user_credentials"
728+
726729
_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
727730

728-
def __init__(self):
731+
def __init__(self, credential_type: Optional[str]):
729732
self.expiry_time = -1
730733
self.current_token = None
731734
self.refresh_duration = 300 # 300 Seconds
735+
self.credential_type = credential_type
732736

733-
def should_fetch_model_serving_environment_oauth(self) -> bool:
737+
def should_fetch_model_serving_environment_oauth() -> bool:
734738
"""
735739
Check whether this is the model serving environment
736740
Additionally check if the oauth token file path exists
@@ -739,15 +743,15 @@ def should_fetch_model_serving_environment_oauth(self) -> bool:
739743
is_in_model_serving_env = (os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
740744
or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV") or "false")
741745
return (is_in_model_serving_env == "true"
742-
and os.path.isfile(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))
746+
and os.path.isfile(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))
743747

744-
def get_model_dependency_oauth_token(self, should_retry=True) -> str:
748+
def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
745749
# Use Cached value if it is valid
746750
if self.current_token is not None and self.expiry_time > time.time():
747751
return self.current_token
748752

749753
try:
750-
with open(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
754+
with open(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
751755
oauth_dict = json.load(f)
752756
self.current_token = oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]
753757
self.expiry_time = time.time() + self.refresh_duration
@@ -757,32 +761,43 @@ def get_model_dependency_oauth_token(self, should_retry=True) -> str:
757761
logger.warning("Unable to read oauth token on first attmept in Model Serving Environment",
758762
exc_info=e)
759763
time.sleep(0.5)
760-
return self.get_model_dependency_oauth_token(should_retry=False)
764+
return self._get_model_dependency_oauth_token(should_retry=False)
761765
else:
762766
raise RuntimeError(
763767
"Unable to read OAuth credentials from the file mounted in Databricks Model Serving"
764768
) from e
765769
return self.current_token
766770

771+
def _get_invokers_token(self):
772+
current_thread = threading.current_thread()
773+
thread_data = current_thread.__dict__
774+
invokers_token = None
775+
if "invokers_token" in thread_data:
776+
invokers_token = thread_data["invokers_token"]
777+
778+
if invokers_token is None:
779+
raise RuntimeError("Unable to read Invokers Token in Databricks Model Serving")
780+
781+
return invokers_token
782+
767783
def get_databricks_host_token(self) -> Optional[Tuple[str, str]]:
768-
if not self.should_fetch_model_serving_environment_oauth():
784+
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
769785
return None
770786

771787
# read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
772788
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
773789
"DB_MODEL_SERVING_HOST_URL")
774-
token = self.get_model_dependency_oauth_token()
775790

776-
return (host, token)
791+
if self.credential_type == ModelServingAuthProvider.USER_CREDENTIALS:
792+
return (host, self._get_invokers_token())
793+
else:
794+
return (host, self._get_model_dependency_oauth_token())
777795

778796

779-
@credentials_strategy('model-serving', [])
780-
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
797+
def model_serving_auth_visitor(cfg: 'Config',
798+
credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
781799
try:
782-
model_serving_auth_provider = ModelServingAuthProvider()
783-
if not model_serving_auth_provider.should_fetch_model_serving_environment_oauth():
784-
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
785-
return None
800+
model_serving_auth_provider = ModelServingAuthProvider(credential_type)
786801
host, token = model_serving_auth_provider.get_databricks_host_token()
787802
if token is None:
788803
raise ValueError(
@@ -793,7 +808,6 @@ def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
793808
except Exception as e:
794809
logger.warning("Unable to get auth from Databricks Model Serving Environment", exc_info=e)
795810
return None
796-
797811
logger.info("Using Databricks Model Serving Authentication")
798812

799813
def inner() -> Dict[str, str]:
@@ -804,6 +818,15 @@ def inner() -> Dict[str, str]:
804818
return inner
805819

806820

821+
@credentials_strategy('model-serving', [])
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
    """
    Credentials strategy registered under the name 'model-serving'.

    Delegates to model_serving_auth_visitor (with its default credential type)
    when ModelServingAuthProvider reports a model serving environment;
    otherwise logs a debug message and returns None.
    """
    in_model_serving = ModelServingAuthProvider.should_fetch_model_serving_environment_oauth()
    if in_model_serving:
        return model_serving_auth_visitor(cfg)

    logger.debug("model-serving: Not in Databricks Model Serving, skipping")
    return None
828+
829+
807830
class DefaultCredentials:
808831
""" Select the first applicable credential provider from the chain """
809832

@@ -846,3 +869,35 @@ def __call__(self, cfg: 'Config') -> CredentialsProvider:
846869
raise ValueError(
847870
f'cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method.'
848871
)
872+
873+
874+
class ModelServingUserCredentials(CredentialsStrategy):
    """
    Credentials strategy that authenticates the Databricks SDK with user-specific
    credentials when running in the model serving environment.

    When ModelServingAuthProvider.should_fetch_model_serving_environment_oauth()
    reports a model serving environment, authentication is delegated to
    model_serving_auth_visitor with the user-credentials type. In every other
    environment the strategy falls back to DefaultCredentials.

    Usage:

        invokers_client = WorkspaceClient(credentials_strategy=ModelServingUserCredentials())
    """

    def __init__(self):
        # Credential type handed to the model serving provider.
        self.credential_type = ModelServingAuthProvider.USER_CREDENTIALS
        # Fallback used whenever the code is not running in model serving.
        self.default_credentials = DefaultCredentials()

    def auth_type(self):
        """Return the auth type name: model-serving specific, or the fallback's."""
        if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
            return self.default_credentials.auth_type()
        return "model_serving_" + self.credential_type

    def __call__(self, cfg: 'Config') -> CredentialsProvider:
        """
        Build the credentials provider for ``cfg``.

        :raises ValueError: if running in model serving and the visitor
            returns no header factory.
        """
        if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
            return self.default_credentials(cfg)

        header_factory = model_serving_auth_visitor(cfg, self.credential_type)
        if header_factory:
            return header_factory

        raise ValueError(
            f"Unable to authenticate using {self.credential_type} in Databricks Model Serving Environment"
        )

tests/test_model_serving_auth.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import threading
12
import time
23

34
import pytest
45

56
from databricks.sdk.core import Config
7+
from databricks.sdk.credentials_provider import ModelServingUserCredentials
68

79
from .conftest import raises
810

@@ -39,7 +41,6 @@ def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeyp
3941
mocker.patch('databricks.sdk.config.Config._known_file_config_loader')
4042

4143
cfg = Config()
42-
4344
assert cfg.auth_type == 'model-serving'
4445
headers = cfg.authenticate()
4546
assert (cfg.host == 'x')
@@ -93,7 +94,6 @@ def test_model_serving_auth_refresh(monkeypatch, mocker):
9394
assert (cfg.host == 'x')
9495
assert headers.get(
9596
"Authorization") == 'Bearer databricks_sdk_unit_test_token' # Token defined in the test file
96-
9797
# Simulate refreshing the token by patching to to a new file
9898
monkeypatch.setattr(
9999
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
@@ -113,3 +113,49 @@ def test_model_serving_auth_refresh(monkeypatch, mocker):
113113
assert (cfg.host == 'x')
114114
# Read V2 now
115115
assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token_v2'
116+
117+
118+
def test_agent_user_credentials(monkeypatch, mocker):
    """
    In a model serving environment, ModelServingUserCredentials must authenticate
    with the invoker's token stored on the current thread, and must pick up a
    changed token on the next authenticate() call.
    """
    monkeypatch.setenv('IS_IN_DB_MODEL_SERVING_ENV', 'true')
    monkeypatch.setenv('DB_MODEL_SERVING_HOST_URL', 'x')
    monkeypatch.setattr(
        "databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
        "tests/testdata/model-serving-test-token")

    # Use monkeypatch.setitem rather than assigning into the thread's __dict__
    # directly, so the token is removed again at teardown and cannot leak into
    # other tests that run on the same thread.
    thread_data = threading.current_thread().__dict__

    invokers_token_val = "databricks_invokers_token"
    monkeypatch.setitem(thread_data, "invokers_token", invokers_token_val)

    cfg = Config(credentials_strategy=ModelServingUserCredentials())
    assert cfg.auth_type == 'model_serving_user_credentials'

    headers = cfg.authenticate()

    assert (cfg.host == 'x')
    assert headers.get("Authorization") == f'Bearer {invokers_token_val}'

    # Test updates of invokers token
    invokers_token_val = "databricks_invokers_token_v2"
    monkeypatch.setitem(thread_data, "invokers_token", invokers_token_val)

    headers = cfg.authenticate()
    assert (cfg.host == 'x')
    assert headers.get("Authorization") == f'Bearer {invokers_token_val}'
147+
148+
149+
# If this credential strategy is used outside a model serving environment, the default credential strategy is used instead
150+
def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch):
    """
    With only DATABRICKS_HOST and DATABRICKS_TOKEN set here,
    ModelServingUserCredentials is expected to resolve to PAT authentication.
    """
    for name, value in (('DATABRICKS_HOST', 'x'), ('DATABRICKS_TOKEN', 'token')):
        monkeypatch.setenv(name, value)

    cfg = Config(credentials_strategy=ModelServingUserCredentials())
    # Auth type is PAT as it is no longer in a model serving environment
    assert cfg.auth_type == 'pat'

    auth_headers = cfg.authenticate()

    assert cfg.host == 'https://x'
    assert auth_headers.get("Authorization") == 'Bearer token'

0 commit comments

Comments
 (0)