1+ import threading
12import time
23
34import pytest
45
56from databricks .sdk .core import Config
7+ from databricks .sdk .credentials_provider import (AgentEmbeddedCredentials ,
8+ AgentUserCredentials )
69
710from .conftest import raises
811
2427 ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV' , 'true' ),
2528 ('DATABRICKS_MODEL_SERVING_HOST_URL' , 'x' )
2629 ], ['DB_MODEL_SERVING_HOST_URL' ], "tests/testdata/model-serving-test-token" ), ])
27- def test_model_serving_auth (env_values , del_env_values , oauth_file_name , monkeypatch , mocker ):
30+ @pytest .mark .parametrize ("use_credential_strategy" , [True , False ])
31+ def test_model_serving_auth (env_values , del_env_values , oauth_file_name , use_credential_strategy , monkeypatch ,
32+ mocker ):
2833 ## In mlflow we check for these two environment variables to return the correct config
2934 for (env_name , env_value ) in env_values :
3035 monkeypatch .setenv (env_name , env_value )
@@ -37,26 +42,25 @@ def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeyp
3742 "databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH" ,
3843 oauth_file_name )
3944 mocker .patch ('databricks.sdk.config.Config._known_file_config_loader' )
40-
41- cfg = Config ()
42-
43- assert cfg .auth_type == 'model-serving'
45+ if use_credential_strategy :
46+ cfg = Config (credentials_strategy = AgentEmbeddedCredentials ())
47+ assert cfg .auth_type == 'agent_embedded_credentials'
48+ else :
49+ cfg = Config ()
50+ assert cfg .auth_type == 'model-serving'
4451 headers = cfg .authenticate ()
4552 assert (cfg .host == 'x' )
4653 # Token defined in the test file
4754 assert headers .get ("Authorization" ) == 'Bearer databricks_sdk_unit_test_token'
4855
4956
50- @pytest .mark .parametrize (
51- "env_values, oauth_file_name" ,
52- [
53- ([], "invalid_file_name" ), # Not in Model Serving and Invalid File Name
54- ([('IS_IN_DB_MODEL_SERVING_ENV' , 'true' )
55- ], "invalid_file_name" ), # In Model Serving and Invalid File Name
56- ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV' , 'true' )
57- ], "invalid_file_name" ), # In Model Serving and Invalid File Name
58- ([], "tests/testdata/model-serving-test-token" ) # Not in Model Serving and Valid File Name
59- ])
57+ @pytest .mark .parametrize ("env_values, oauth_file_name" , [
58+ ([], "invalid_file_name" ), # Not in Model Serving and Invalid File Name
59+ ([('IS_IN_DB_MODEL_SERVING_ENV' , 'true' )], "invalid_file_name" ), # In Model Serving and Invalid File Name
60+ ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV' , 'true' )
61+ ], "invalid_file_name" ), # In Model Serving and Invalid File Name
62+ ([], "tests/testdata/model-serving-test-token" ) # Not in Model Serving and Valid File Name
63+ ])
6064@raises (default_auth_base_error_message )
6165def test_model_serving_auth_errors (env_values , oauth_file_name , monkeypatch ):
6266 # Guarantee that the tests defaults to env variables rather than config file.
@@ -74,7 +78,8 @@ def test_model_serving_auth_errors(env_values, oauth_file_name, monkeypatch):
7478 Config ()
7579
7680
77- def test_model_serving_auth_refresh (monkeypatch , mocker ):
81+ @pytest .mark .parametrize ("use_credential_strategy" , [True , False ])
82+ def test_model_serving_auth_refresh (use_credential_strategy , monkeypatch , mocker ):
7883 ## In mlflow we check for these two environment variables to return the correct config
7984 monkeypatch .setenv ('IS_IN_DB_MODEL_SERVING_ENV' , 'true' )
8085 monkeypatch .setenv ('DB_MODEL_SERVING_HOST_URL' , 'x' )
@@ -85,15 +90,18 @@ def test_model_serving_auth_refresh(monkeypatch, mocker):
8590 "tests/testdata/model-serving-test-token" )
8691 mocker .patch ('databricks.sdk.config.Config._known_file_config_loader' )
8792
88- cfg = Config ()
89- assert cfg .auth_type == 'model-serving'
93+ if use_credential_strategy :
94+ cfg = Config (credentials_strategy = AgentEmbeddedCredentials ())
95+ assert cfg .auth_type == 'agent_embedded_credentials'
96+ else :
97+ cfg = Config ()
98+ assert cfg .auth_type == 'model-serving'
9099
91100 current_time = time .time ()
92101 headers = cfg .authenticate ()
93102 assert (cfg .host == 'x' )
94103 assert headers .get (
95104 "Authorization" ) == 'Bearer databricks_sdk_unit_test_token' # Token defined in the test file
96-
97105 # Simulate refreshing the token by patching to to a new file
98106 monkeypatch .setattr (
99107 "databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH" ,
@@ -113,3 +121,64 @@ def test_model_serving_auth_refresh(monkeypatch, mocker):
113121 assert (cfg .host == 'x' )
114122 # Read V2 now
115123 assert headers .get ("Authorization" ) == 'Bearer databricks_sdk_unit_test_token_v2'
124+
125+
126+ def test_agent_user_credentials (monkeypatch , mocker ):
127+ monkeypatch .setenv ('IS_IN_DB_MODEL_SERVING_ENV' , 'true' )
128+ monkeypatch .setenv ('DB_MODEL_SERVING_HOST_URL' , 'x' )
129+ monkeypatch .setattr (
130+ "databricks.sdk.credentials_provider.ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH" ,
131+ "tests/testdata/model-serving-test-token" )
132+
133+ invokers_token_val = "databricks_invokers_token"
134+ current_thread = threading .current_thread ()
135+ thread_data = current_thread .__dict__
136+ thread_data ["invokers_token" ] = invokers_token_val
137+
138+ cfg = Config (credentials_strategy = AgentUserCredentials ())
139+ assert cfg .auth_type == 'agent_user_credentials'
140+
141+ headers = cfg .authenticate ()
142+
143+ assert (cfg .host == 'x' )
144+ assert headers .get ("Authorization" ) == f'Bearer { invokers_token_val } '
145+
146+ # Test updates of invokers token
147+ invokers_token_val = "databricks_invokers_token_v2"
148+ current_thread = threading .current_thread ()
149+ thread_data = current_thread .__dict__
150+ thread_data ["invokers_token" ] = invokers_token_val
151+
152+ headers = cfg .authenticate ()
153+ assert (cfg .host == 'x' )
154+ assert headers .get ("Authorization" ) == f'Bearer { invokers_token_val } '
155+
156+
157+ # If this credential strategy is being used in a non model serving environments then use default credential strategy instead
158+ def test_agent_user_credentials_in_non_model_serving_environments (monkeypatch ):
159+
160+ monkeypatch .setenv ('DATABRICKS_HOST' , 'x' )
161+ monkeypatch .setenv ('DATABRICKS_TOKEN' , 'token' )
162+
163+ cfg = Config (credentials_strategy = AgentUserCredentials ())
164+ assert cfg .auth_type == 'pat' # Auth type is PAT as it is no longer in a model serving environment
165+
166+ headers = cfg .authenticate ()
167+
168+ assert (cfg .host == 'https://x' )
169+ assert headers .get ("Authorization" ) == f'Bearer token'
170+
171+
172+ # If this credential strategy is being used in a non model serving environments then use default credential strategy instead
173+ def test_agent_embedded_credentials_in_non_model_serving_environments (monkeypatch ):
174+
175+ monkeypatch .setenv ('DATABRICKS_HOST' , 'x' )
176+ monkeypatch .setenv ('DATABRICKS_TOKEN' , 'token' )
177+
178+ cfg = Config (credentials_strategy = AgentEmbeddedCredentials ())
179+ assert cfg .auth_type == 'pat' # Auth type is PAT as it is no longer in a model serving environment
180+
181+ headers = cfg .authenticate ()
182+
183+ assert (cfg .host == 'https://x' )
184+ assert headers .get ("Authorization" ) == f'Bearer token'
0 commit comments