Skip to content

Commit 1c6cd26

Browse files
committed
Update Credential Strategy
1 parent 037a853 commit 1c6cd26

File tree

2 files changed

+36
-74
lines changed

2 files changed

+36
-74
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -314,12 +314,11 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
314314
# detect Azure AD Tenant ID if it's not specified directly
315315
token_endpoint = cfg.oidc_endpoints.token_endpoint
316316
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0]
317-
inner = ClientCredentials(
318-
client_id=cfg.azure_client_id,
319-
client_secret="", # we have no (rotatable) secrets in OIDC flow
320-
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
321-
endpoint_params=params,
322-
use_params=True)
317+
inner = ClientCredentials(client_id=cfg.azure_client_id,
318+
client_secret="", # we have no (rotatable) secrets in OIDC flow
319+
token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token",
320+
endpoint_params=params,
321+
use_params=True)
323322

324323
def refreshed_headers() -> Dict[str, str]:
325324
token = inner.token()
@@ -725,11 +724,10 @@ def inner() -> Dict[str, str]:
725724
# https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
726725
class ModelServingAuthProvider():
727726
USER_CREDENTIALS = "user_credentials"
728-
EMBEDDED_CREDENTIALS = "embedded_credentials"
729727

730728
_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
731729

732-
def __init__(self, credential_type):
730+
def __init__(self, credential_type: Optional[str]):
733731
self.expiry_time = -1
734732
self.current_token = None
735733
self.refresh_duration = 300 # 300 Seconds
@@ -746,7 +744,7 @@ def should_fetch_model_serving_environment_oauth() -> bool:
746744
return (is_in_model_serving_env == "true"
747745
and os.path.isfile(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))
748746

749-
def get_model_dependency_oauth_token(self, should_retry=True) -> str:
747+
def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
750748
# Use Cached value if it is valid
751749
if self.current_token is not None and self.expiry_time > time.time():
752750
return self.current_token
@@ -762,14 +760,14 @@ def get_model_dependency_oauth_token(self, should_retry=True) -> str:
762760
logger.warning("Unable to read oauth token on first attmept in Model Serving Environment",
763761
exc_info=e)
764762
time.sleep(0.5)
765-
return self.get_model_dependency_oauth_token(should_retry=False)
763+
return self._get_model_dependency_oauth_token(should_retry=False)
766764
else:
767765
raise RuntimeError(
768766
"Unable to read OAuth credentials from the file mounted in Databricks Model Serving"
769767
) from e
770768
return self.current_token
771769

772-
def get_invokers_token(self):
770+
def _get_invokers_token(self):
773771
current_thread = threading.current_thread()
774772
thread_data = current_thread.__dict__
775773
invokers_token = None
@@ -788,18 +786,16 @@ def get_databricks_host_token(self) -> Optional[Tuple[str, str]]:
788786
# read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
789787
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
790788
"DB_MODEL_SERVING_HOST_URL")
791-
token = self.get_model_dependency_oauth_token(
792-
) if self.credential_type == ModelServingAuthProvider.EMBEDDED_CREDENTIALS else self.get_invokers_token(
793-
)
794789

795-
return (host, token)
790+
if self.credential_type == ModelServingAuthProvider.USER_CREDENTIALS:
791+
return (host, self._get_invokers_token())
792+
else:
793+
return (host, self._get_model_dependency_oauth_token())
796794

797795

798-
def model_serving_auth_func(cfg: 'Config', credential_type) -> Optional[CredentialsProvider]:
796+
def model_serving_auth_visitor(cfg: 'Config',
797+
credential_type: Optional[str] = None) -> Optional[CredentialsProvider]:
799798
try:
800-
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
801-
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
802-
return None
803799
model_serving_auth_provider = ModelServingAuthProvider(credential_type)
804800
host, token = model_serving_auth_provider.get_databricks_host_token()
805801
if token is None:
@@ -823,7 +819,11 @@ def inner() -> Dict[str, str]:
823819

824820
@credentials_strategy('model-serving', [])
825821
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
826-
return model_serving_auth_func(cfg, ModelServingAuthProvider.EMBEDDED_CREDENTIALS)
822+
if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
823+
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
824+
return None
825+
826+
return model_serving_auth_visitor(cfg)
827827

828828

829829
class DefaultCredentials:
@@ -870,37 +870,25 @@ def __call__(self, cfg: 'Config') -> CredentialsProvider:
870870
)
871871

872872

873-
class AgentCredentials(CredentialsStrategy):
873+
class ModelServingUserCredentials(CredentialsStrategy):
874874

875-
def __init__(self, credential_type):
876-
self.credential_type = credential_type
875+
def __init__(self):
876+
self.credential_type = ModelServingAuthProvider.USER_CREDENTIALS
877877
self.default_credentials = DefaultCredentials()
878878

879879
def auth_type(self):
880880
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
881-
return "agent_" + self.credential_type
881+
return "model_serving_" + self.credential_type
882882
else:
883883
return self.default_credentials.auth_type()
884884

885885
def __call__(self, cfg: 'Config') -> CredentialsProvider:
886886
if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth():
887-
header_factory = model_serving_auth_func(cfg, self.credential_type)
887+
header_factory = model_serving_auth_visitor(cfg, self.credential_type)
888888
if not header_factory:
889889
raise ValueError(
890890
f"Unable to authenticate using {self.credential_type} in Databricks Model Serving Environment"
891891
)
892892
return header_factory
893893
else:
894894
return self.default_credentials(cfg)
895-
896-
897-
class AgentUserCredentials(AgentCredentials):
898-
899-
def __init__(self):
900-
super().__init__(ModelServingAuthProvider.USER_CREDENTIALS)
901-
902-
903-
class AgentEmbeddedCredentials(AgentCredentials):
904-
905-
def __init__(self):
906-
super().__init__(ModelServingAuthProvider.EMBEDDED_CREDENTIALS)

tests/test_model_serving_auth.py

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import pytest
55

66
from databricks.sdk.core import Config
7-
from databricks.sdk.credentials_provider import (AgentEmbeddedCredentials,
8-
AgentUserCredentials)
7+
from databricks.sdk.credentials_provider import ModelServingUserCredentials
98

109
from .conftest import raises
1110

@@ -27,9 +26,7 @@
2726
([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true'),
2827
('DATABRICKS_MODEL_SERVING_HOST_URL', 'x')
2928
], ['DB_MODEL_SERVING_HOST_URL'], "tests/testdata/model-serving-test-token"), ])
30-
@pytest.mark.parametrize("use_credential_strategy", [True, False])
31-
def test_model_serving_auth(env_values, del_env_values, oauth_file_name, use_credential_strategy, monkeypatch,
32-
mocker):
29+
def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeypatch, mocker):
3330
## In mlflow we check for these two environment variables to return the correct config
3431
for (env_name, env_value) in env_values:
3532
monkeypatch.setenv(env_name, env_value)
@@ -42,12 +39,9 @@ def test_model_serving_auth(env_values, del_env_values, oauth_file_name, use_cre
4239
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
4340
oauth_file_name)
4441
mocker.patch('databricks.sdk.config.Config._known_file_config_loader')
45-
if use_credential_strategy:
46-
cfg = Config(credentials_strategy=AgentEmbeddedCredentials())
47-
assert cfg.auth_type == 'agent_embedded_credentials'
48-
else:
49-
cfg = Config()
50-
assert cfg.auth_type == 'model-serving'
42+
43+
cfg = Config()
44+
assert cfg.auth_type == 'model-serving'
5145
headers = cfg.authenticate()
5246
assert (cfg.host == 'x')
5347
# Token defined in the test file
@@ -78,8 +72,7 @@ def test_model_serving_auth_errors(env_values, oauth_file_name, monkeypatch):
7872
Config()
7973

8074

81-
@pytest.mark.parametrize("use_credential_strategy", [True, False])
82-
def test_model_serving_auth_refresh(use_credential_strategy, monkeypatch, mocker):
75+
def test_model_serving_auth_refresh(monkeypatch, mocker):
8376
## In mlflow we check for these two environment variables to return the correct config
8477
monkeypatch.setenv('IS_IN_DB_MODEL_SERVING_ENV', 'true')
8578
monkeypatch.setenv('DB_MODEL_SERVING_HOST_URL', 'x')
@@ -90,12 +83,8 @@ def test_model_serving_auth_refresh(use_credential_strategy, monkeypatch, mocker
9083
"tests/testdata/model-serving-test-token")
9184
mocker.patch('databricks.sdk.config.Config._known_file_config_loader')
9285

93-
if use_credential_strategy:
94-
cfg = Config(credentials_strategy=AgentEmbeddedCredentials())
95-
assert cfg.auth_type == 'agent_embedded_credentials'
96-
else:
97-
cfg = Config()
98-
assert cfg.auth_type == 'model-serving'
86+
cfg = Config()
87+
assert cfg.auth_type == 'model-serving'
9988

10089
current_time = time.time()
10190
headers = cfg.authenticate()
@@ -135,8 +124,8 @@ def test_agent_user_credentials(monkeypatch, mocker):
135124
thread_data = current_thread.__dict__
136125
thread_data["invokers_token"] = invokers_token_val
137126

138-
cfg = Config(credentials_strategy=AgentUserCredentials())
139-
assert cfg.auth_type == 'agent_user_credentials'
127+
cfg = Config(credentials_strategy=ModelServingUserCredentials())
128+
assert cfg.auth_type == 'model_serving_user_credentials'
140129

141130
headers = cfg.authenticate()
142131

@@ -160,22 +149,7 @@ def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch):
160149
monkeypatch.setenv('DATABRICKS_HOST', 'x')
161150
monkeypatch.setenv('DATABRICKS_TOKEN', 'token')
162151

163-
cfg = Config(credentials_strategy=AgentUserCredentials())
164-
assert cfg.auth_type == 'pat' # Auth type is PAT as it is no longer in a model serving environment
165-
166-
headers = cfg.authenticate()
167-
168-
assert (cfg.host == 'https://x')
169-
assert headers.get("Authorization") == f'Bearer token'
170-
171-
172-
# If this credential strategy is being used in a non model serving environments then use default credential strategy instead
173-
def test_agent_embedded_credentials_in_non_model_serving_environments(monkeypatch):
174-
175-
monkeypatch.setenv('DATABRICKS_HOST', 'x')
176-
monkeypatch.setenv('DATABRICKS_TOKEN', 'token')
177-
178-
cfg = Config(credentials_strategy=AgentEmbeddedCredentials())
152+
cfg = Config(credentials_strategy=ModelServingUserCredentials())
179153
assert cfg.auth_type == 'pat' # Auth type is PAT as it is no longer in a model serving environment
180154

181155
headers = cfg.authenticate()

0 commit comments

Comments
 (0)