Skip to content

Commit 9d39254

Browse files
authored
[Feature] Integrate Databricks SDK with Model Serving Auth Provider (#761)
## Changes This PR introduces a new model serving auth method to Databricks SDK. - If the correct environment variables are set to identify a model serving environment - Check to see if there is an oauth file written by the serving environment - If this file exists use the token here for authentication ## Tests Added Unit tests - [x] `make test` run locally - [x] `make fmt` applied - [x] relevant integration tests applied --------- Signed-off-by: aravind-segu <[email protected]>
1 parent d5ec433 commit 9d39254

File tree

4 files changed

+199
-2
lines changed

4 files changed

+199
-2
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
import platform
1010
import subprocess
1111
import sys
12+
import time
1213
from datetime import datetime
13-
from typing import Callable, Dict, List, Optional, Union
14+
from typing import Callable, Dict, List, Optional, Tuple, Union
1415

1516
import google.auth
1617
import requests
@@ -698,6 +699,90 @@ def inner() -> Dict[str, str]:
698699
return inner
699700

700701

702+
# This Code is derived from Mlflow DatabricksModelServingConfigProvider
703+
# https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
704+
class ModelServingAuthProvider():
705+
_MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
706+
707+
def __init__(self):
708+
self.expiry_time = -1
709+
self.current_token = None
710+
self.refresh_duration = 300 # 300 Seconds
711+
712+
def should_fetch_model_serving_environment_oauth(self) -> bool:
713+
"""
714+
Check whether this is the model serving environment
715+
Additionally check if the oauth token file path exists
716+
"""
717+
718+
is_in_model_serving_env = (os.environ.get("IS_IN_DB_MODEL_SERVING_ENV")
719+
or os.environ.get("IS_IN_DATABRICKS_MODEL_SERVING_ENV") or "false")
720+
return (is_in_model_serving_env == "true"
721+
and os.path.isfile(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH))
722+
723+
def get_model_dependency_oauth_token(self, should_retry=True) -> str:
724+
# Use Cached value if it is valid
725+
if self.current_token is not None and self.expiry_time > time.time():
726+
return self.current_token
727+
728+
try:
729+
with open(self._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f:
730+
oauth_dict = json.load(f)
731+
self.current_token = oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"]
732+
self.expiry_time = time.time() + self.refresh_duration
733+
except Exception as e:
734+
# sleep and retry in case of any race conditions with OAuth refreshing
735+
if should_retry:
736+
logger.warning("Unable to read oauth token on first attmept in Model Serving Environment",
737+
exc_info=e)
738+
time.sleep(0.5)
739+
return self.get_model_dependency_oauth_token(should_retry=False)
740+
else:
741+
raise RuntimeError(
742+
"Unable to read OAuth credentials from the file mounted in Databricks Model Serving"
743+
) from e
744+
return self.current_token
745+
746+
def get_databricks_host_token(self) -> Optional[Tuple[str, str]]:
747+
if not self.should_fetch_model_serving_environment_oauth():
748+
return None
749+
750+
# read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
751+
host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get(
752+
"DB_MODEL_SERVING_HOST_URL")
753+
token = self.get_model_dependency_oauth_token()
754+
755+
return (host, token)
756+
757+
758+
@credentials_strategy('model-serving', [])
759+
def model_serving_auth(cfg: 'Config') -> Optional[CredentialsProvider]:
760+
try:
761+
model_serving_auth_provider = ModelServingAuthProvider()
762+
if not model_serving_auth_provider.should_fetch_model_serving_environment_oauth():
763+
logger.debug("model-serving: Not in Databricks Model Serving, skipping")
764+
return None
765+
host, token = model_serving_auth_provider.get_databricks_host_token()
766+
if token is None:
767+
raise ValueError(
768+
"Got malformed auth (empty token) when fetching auth implicitly available in Model Serving Environment. Please contact Databricks support"
769+
)
770+
if cfg.host is None:
771+
cfg.host = host
772+
except Exception as e:
773+
logger.warning("Unable to get auth from Databricks Model Serving Environment", exc_info=e)
774+
return None
775+
776+
logger.info("Using Databricks Model Serving Authentication")
777+
778+
def inner() -> Dict[str, str]:
779+
# Call here again to get the refreshed token
780+
_, token = model_serving_auth_provider.get_databricks_host_token()
781+
return {"Authorization": f"Bearer {token}"}
782+
783+
return inner
784+
785+
701786
class DefaultCredentials:
702787
""" Select the first applicable credential provider from the chain """
703788

@@ -706,7 +791,7 @@ def __init__(self) -> None:
706791
self._auth_providers = [
707792
pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
708793
github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth,
709-
google_credentials, google_id
794+
google_credentials, google_id, model_serving_auth
710795
]
711796

712797
def auth_type(self) -> str:

tests/test_model_serving_auth.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import time
2+
3+
import pytest
4+
5+
from databricks.sdk.core import Config
6+
7+
from .conftest import raises
8+
9+
default_auth_base_error_message = \
10+
"default auth: cannot configure default credentials, " \
11+
"please check https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication " \
12+
"to configure credentials for your preferred authentication method"
13+
14+
15+
@pytest.mark.parametrize(
16+
"env_values, oauth_file_name",
17+
[([('IS_IN_DB_MODEL_SERVING_ENV', 'true'),
18+
('DB_MODEL_SERVING_HOST_URL', 'x')], "tests/testdata/model-serving-test-token"),
19+
([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true'),
20+
('DB_MODEL_SERVING_HOST_URL', 'x')], "tests/testdata/model-serving-test-token"),
21+
([('IS_IN_DB_MODEL_SERVING_ENV', 'true'),
22+
('DATABRICKS_MODEL_SERVING_HOST_URL', 'x')], "tests/testdata/model-serving-test-token"),
23+
([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true'),
24+
('DATABRICKS_MODEL_SERVING_HOST_URL', 'x')], "tests/testdata/model-serving-test-token"), ])
25+
def test_model_serving_auth(env_values, oauth_file_name, monkeypatch):
26+
## In mlflow we check for these two environment variables to return the correct config
27+
for (env_name, env_value) in env_values:
28+
monkeypatch.setenv(env_name, env_value)
29+
# patch mlflow to read the file from the test directory
30+
monkeypatch.setattr(
31+
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
32+
oauth_file_name)
33+
34+
cfg = Config()
35+
36+
assert cfg.auth_type == 'model-serving'
37+
headers = cfg.authenticate()
38+
assert (cfg.host == 'x')
39+
# Token defined in the test file
40+
assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token'
41+
42+
43+
@pytest.mark.parametrize("env_values, oauth_file_name", [
44+
([], "invalid_file_name"), # Not in Model Serving and Invalid File Name
45+
([('IS_IN_DB_MODEL_SERVING_ENV', 'true')], "invalid_file_name"), # In Model Serving and Invalid File Name
46+
([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true')
47+
], "invalid_file_name"), # In Model Serving and Invalid File Name
48+
([], "tests/testdata/model-serving-test-token") # Not in Model Serving and Valid File Name
49+
])
50+
@raises(default_auth_base_error_message)
51+
def test_model_serving_auth_errors(env_values, oauth_file_name, monkeypatch):
52+
for (env_name, env_value) in env_values:
53+
monkeypatch.setenv(env_name, env_value)
54+
monkeypatch.setattr(
55+
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
56+
oauth_file_name)
57+
58+
Config()
59+
60+
61+
def test_model_serving_auth_refresh(monkeypatch):
62+
## In mlflow we check for these two environment variables to return the correct config
63+
monkeypatch.setenv('IS_IN_DB_MODEL_SERVING_ENV', 'true')
64+
monkeypatch.setenv('DB_MODEL_SERVING_HOST_URL', 'x')
65+
66+
# patch mlflow to read the file from the test directory
67+
monkeypatch.setattr(
68+
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
69+
"tests/testdata/model-serving-test-token")
70+
71+
cfg = Config()
72+
assert cfg.auth_type == 'model-serving'
73+
74+
current_time = time.time()
75+
headers = cfg.authenticate()
76+
assert (cfg.host == 'x')
77+
assert headers.get(
78+
"Authorization") == 'Bearer databricks_sdk_unit_test_token' # Token defined in the test file
79+
80+
# Simulate refreshing the token by patching to to a new file
81+
monkeypatch.setattr(
82+
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
83+
"tests/testdata/model-serving-test-token-v2")
84+
85+
monkeypatch.setattr('databricks.sdk.credentials_provider.time.time', lambda: current_time + 10)
86+
87+
headers = cfg.authenticate()
88+
assert (cfg.host == 'x')
89+
# Read from cache even though new path is set because expiry is still not hit
90+
assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token'
91+
92+
# Expiry is 300 seconds so this should force an expiry and re read from the new file path
93+
monkeypatch.setattr('databricks.sdk.credentials_provider.time.time', lambda: current_time + 600)
94+
95+
headers = cfg.authenticate()
96+
assert (cfg.host == 'x')
97+
# Read V2 now
98+
assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token_v2'
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"OAUTH_TOKEN": [
3+
{
4+
"oauthTokenValue": "databricks_sdk_unit_test_token"
5+
}
6+
]
7+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"OAUTH_TOKEN": [
3+
{
4+
"oauthTokenValue": "databricks_sdk_unit_test_token_v2"
5+
}
6+
]
7+
}

0 commit comments

Comments
 (0)