Skip to content

Commit 6f7389c

Browse files
committed
Ensure Invokers Rights only looks at main thread for invokers token
Signed-off-by: aravind-segu <[email protected]>
1 parent 83a921f commit 6f7389c

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,8 +769,8 @@ def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
769769
return self.current_token
770770

771771
def _get_invokers_token(self):
772-
current_thread = threading.current_thread()
773-
thread_data = current_thread.__dict__
772+
main_thread = threading.main_thread()
773+
thread_data = main_thread.__dict__
774774
invokers_token = None
775775
if "invokers_token" in thread_data:
776776
invokers_token = thread_data["invokers_token"]

tests/test_model_serving_auth.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def test_model_serving_auth_refresh(monkeypatch, mocker):
114114
# Read V2 now
115115
assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token_v2'
116116

117-
118117
def test_agent_user_credentials(monkeypatch, mocker):
119118
monkeypatch.setenv('IS_IN_DB_MODEL_SERVING_ENV', 'true')
120119
monkeypatch.setenv('DB_MODEL_SERVING_HOST_URL', 'x')
@@ -145,6 +144,26 @@ def test_agent_user_credentials(monkeypatch, mocker):
145144
assert (cfg.host == 'x')
146145
assert headers.get("Authorization") == f'Bearer {invokers_token_val}'
147146

147+
# Test invokers token in child thread
148+
149+
successful_authentication_event = threading.Event()
150+
151+
def authenticate():
152+
try:
153+
cfg = Config(credentials_strategy=ModelServingUserCredentials())
154+
headers = cfg.authenticate()
155+
assert (cfg.host == 'x')
156+
assert headers.get("Authorization") == f'Bearer databricks_invokers_token_v2'
157+
successful_authentication_event.set()
158+
except Exception as e:
159+
successful_authentication_event.clear()
160+
161+
thread = threading.Thread(target=authenticate)
162+
163+
thread.start()
164+
thread.join()
165+
assert(successful_authentication_event.is_set())
166+
148167

149168
# If this credential strategy is being used in a non model serving environments then use default credential strategy instead
150169
def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch):

0 commit comments

Comments
 (0)