Skip to content

Commit a1185d2

Browse files
[Feature] Fix Model Serving User Credentials threading scenarios (#907)
## What changes are proposed in this pull request? Previously in order to get invokers tokens, we always looked at the current thread's thread data. however this is not always guaranteed as langchain is internally multi threading calls. Therefore this PR only looks at the main thread's data to get the credentials. The Main Thread is guaranteed to have invokers credentials as we set them initially in the scoring server. ## How is this tested? Added unit tests --------- Signed-off-by: aravind-segu <[email protected]> Co-authored-by: Renaud Hartert <[email protected]>
1 parent 8de985d commit a1185d2

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313
* Update Jobs ListRuns API to support paginated responses ([#890](https://github.com/databricks/databricks-sdk-py/pull/890))
1414
* Introduce automated tagging ([#888](https://github.com/databricks/databricks-sdk-py/pull/888))
1515
* Update Jobs GetJob API to support paginated responses ([#869](https://github.com/databricks/databricks-sdk-py/pull/869)).
16+
* Update On Behalf Of User Authentication in Multithreaded applications ([#907](https://github.com/databricks/databricks-sdk-py/pull/907))
1617

1718
### API Changes

databricks/sdk/credentials_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -853,8 +853,8 @@ def _get_model_dependency_oauth_token(self, should_retry=True) -> str:
853853
return self.current_token
854854

855855
def _get_invokers_token(self):
856-
current_thread = threading.current_thread()
857-
thread_data = current_thread.__dict__
856+
main_thread = threading.main_thread()
857+
thread_data = main_thread.__dict__
858858
invokers_token = None
859859
if "invokers_token" in thread_data:
860860
invokers_token = thread_data["invokers_token"]

tests/test_model_serving_auth.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,26 @@ def test_agent_user_credentials(monkeypatch, mocker):
198198
assert cfg.host == "x"
199199
assert headers.get("Authorization") == f"Bearer {invokers_token_val}"
200200

201+
# Test invokers token in child thread
202+
203+
successful_authentication_event = threading.Event()
204+
205+
def authenticate():
206+
try:
207+
cfg = Config(credentials_strategy=ModelServingUserCredentials())
208+
headers = cfg.authenticate()
209+
assert cfg.host == "x"
210+
assert headers.get("Authorization") == f"Bearer databricks_invokers_token_v2"
211+
successful_authentication_event.set()
212+
except Exception:
213+
successful_authentication_event.clear()
214+
215+
thread = threading.Thread(target=authenticate)
216+
217+
thread.start()
218+
thread.join()
219+
assert successful_authentication_event.is_set()
220+
201221

202222
# If this credential strategy is being used in a non model serving environments then use default credential strategy instead
203223
def test_agent_user_credentials_in_non_model_serving_environments(monkeypatch):

0 commit comments

Comments
 (0)