Skip to content

Commit 4ff52b0

Browse files
author
Anshul Gupta
committed
Support greenlet local auth for model serving
1 parent a393602 commit 4ff52b0

File tree

3 files changed

+46
-1
lines changed

3 files changed

+46
-1
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,14 @@ def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
942942
) from e
943943
return self.current_token
944944

945+
def _get_invokers_token_from_greenlet(self):
946+
# Attempt to retrieve 'invokers_token' from greenlet local
947+
from greenlet import greenlet, getcurrent
948+
greenlet = getcurrent()
949+
if hasattr(greenlet, 'invokers_token'):
950+
return greenlet.invokers_token
951+
raise RuntimeError("Unable to read Invokers Token in Databricks Model Serving")
952+
945953
def _get_invokers_token(self):
946954
main_thread = threading.main_thread()
947955
thread_data = main_thread.__dict__
@@ -950,7 +958,8 @@ def _get_invokers_token(self):
950958
invokers_token = thread_data["invokers_token"]
951959

952960
if invokers_token is None:
953-
raise RuntimeError("Unable to read Invokers Token in Databricks Model Serving")
961+
# This is likely async server code, so we should check greenlet local
962+
return self._get_invokers_token_from_greenlet()
954963

955964
return invokers_token
956965

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ classifiers = [
2727
dependencies = [
2828
"requests>=2.28.1,<3",
2929
"google-auth~=2.0",
30+
"greenlet",
3031
]
3132

3233
[project.urls]

tests/test_model_serving_auth.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from databricks.sdk.core import Config
77
from databricks.sdk.credentials_provider import ModelServingUserCredentials
8+
from greenlet import greenlet, getcurrent
89

910
from .conftest import raises
1011

@@ -217,7 +218,41 @@ def authenticate():
217218
thread.start()
218219
thread.join()
219220
assert successful_authentication_event.is_set()
221+
del current_thread.__dict__["invokers_token"] # Clean up invokers token
220222

223+
def test_agent_user_credentials_via_greenlet(monkeypatch, mocker):
224+
# Guarantee that the tests defaults to env variables rather than config file.
225+
#
226+
# TODO: this is hacky and we should find a better way to tell the config
227+
# that it should not read from the config file.
228+
monkeypatch.setenv("DATABRICKS_CONFIG_FILE", "x")
229+
230+
monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
231+
monkeypatch.setenv("DB_MODEL_SERVING_HOST_URL", "x")
232+
monkeypatch.setattr(
233+
"databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH",
234+
"tests/testdata/model-serving-test-token",
235+
)
236+
237+
invokers_token_val = "databricks_invokers_token"
238+
greenlet_local = getcurrent()
239+
setattr(greenlet_local, "invokers_token", invokers_token_val)
240+
241+
cfg = Config(credentials_strategy=ModelServingUserCredentials())
242+
assert cfg.auth_type == "model_serving_user_credentials"
243+
244+
headers = cfg.authenticate()
245+
246+
assert cfg.host == "x"
247+
assert headers.get("Authorization") == f"Bearer {invokers_token_val}"
248+
249+
# Test updates of invokers token
250+
invokers_token_val = "databricks_invokers_token_v2"
251+
setattr(greenlet_local, "invokers_token", invokers_token_val)
252+
253+
headers = cfg.authenticate()
254+
assert cfg.host == "x"
255+
assert headers.get("Authorization") == f"Bearer {invokers_token_val}"
221256

222257
# If this credential strategy is being used in a non model serving environments then use default credential strategy instead
223258
def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch):

0 commit comments

Comments
 (0)