Skip to content

Commit 8d205cc

Browse files
authored
Add support for passing additional kwargs when instantiating an OpenAI client for Databricks model serving (#1025)
## What changes are proposed in this pull request?

Per customer request, adds support for passing additional optional kwargs to `databricks.sdk.WorkspaceClient().serving_endpoints.get_open_ai_client()`. This enables use cases like specifying a custom timeout and retry policy for latency-sensitive workloads, e.g.:

```
>>> # Configure custom timeout and retries
>>> client = workspace_client.serving_endpoints.get_open_ai_client(
...     timeout=30.0,
...     max_retries=5
... )
```

## How is this tested?

Updated unit tests.

---------

Signed-off-by: Sid Murching <[email protected]>
1 parent ce437f4 commit 8d205cc

File tree

3 files changed

+117
-6
lines changed

3 files changed

+117
-6
lines changed

NEXT_CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### New Features and Improvements
66

7+
* Added support for passing additional kwargs to `WorkspaceClient().serving_endpoints.get_open_ai_client()` ([#1025](https://github.com/databricks/databricks-sdk-py/pull/1025)). Users can now pass standard OpenAI client parameters like `timeout` and `max_retries` when creating an OpenAI client for Databricks Model Serving.
8+
79
### Bug Fixes
810

911
### Documentation

databricks/sdk/mixins/open_ai_client.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,68 @@ def auth_flow(self, request: httpx.Request) -> httpx.Request:
3131
http_client = httpx.Client(auth=databricks_token_auth)
3232
return http_client
3333

34-
def get_open_ai_client(self):
34+
def get_open_ai_client(self, **kwargs):
35+
"""Create an OpenAI client configured for Databricks Model Serving.
36+
37+
Returns an OpenAI client instance that is pre-configured to send requests to
38+
Databricks Model Serving endpoints. The client uses Databricks authentication
39+
to query endpoints within the workspace associated with the current WorkspaceClient
40+
instance.
41+
42+
Args:
43+
**kwargs: Additional parameters to pass to the OpenAI client constructor.
44+
Common parameters include:
45+
- timeout (float): Request timeout in seconds (e.g., 30.0)
46+
- max_retries (int): Maximum number of retries for failed requests (e.g., 3)
47+
- default_headers (dict): Additional headers to include with requests
48+
- default_query (dict): Additional query parameters to include with requests
49+
50+
Any parameter accepted by the OpenAI client constructor can be passed here,
51+
except for the following parameters which are reserved for Databricks integration:
52+
base_url, api_key, http_client
53+
54+
Returns:
55+
OpenAI: An OpenAI client instance configured for Databricks Model Serving.
56+
57+
Raises:
58+
ImportError: If the OpenAI library is not installed.
59+
ValueError: If any reserved Databricks parameters are provided in kwargs.
60+
61+
Example:
62+
>>> client = workspace_client.serving_endpoints.get_open_ai_client()
63+
>>> # With custom timeout and retries
64+
>>> client = workspace_client.serving_endpoints.get_open_ai_client(
65+
... timeout=30.0,
66+
... max_retries=5
67+
... )
68+
"""
3569
try:
3670
from openai import OpenAI
3771
except Exception:
3872
raise ImportError(
3973
"Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`"
4074
)
4175

42-
return OpenAI(
43-
base_url=self._api._cfg.host + "/serving-endpoints",
44-
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used
45-
http_client=self._get_authorized_http_client(),
46-
)
76+
# Check for reserved parameters that should not be overridden
77+
reserved_params = {"base_url", "api_key", "http_client"}
78+
conflicting_params = reserved_params.intersection(kwargs.keys())
79+
if conflicting_params:
80+
raise ValueError(
81+
f"Cannot override reserved Databricks parameters: {', '.join(sorted(conflicting_params))}. "
82+
f"These parameters are automatically configured for Databricks Model Serving."
83+
)
84+
85+
# Default parameters that are required for Databricks integration
86+
client_params = {
87+
"base_url": self._api._cfg.host + "/serving-endpoints",
88+
"api_key": "no-token", # Passing in a placeholder to pass validations, this will not be used
89+
"http_client": self._get_authorized_http_client(),
90+
}
91+
92+
# Update with any additional parameters passed by the user
93+
client_params.update(kwargs)
94+
95+
return OpenAI(**client_params)
4796

4897
def get_langchain_chat_open_ai_client(self, model):
4998
try:

tests/test_open_ai_mixin.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,66 @@ def test_open_ai_client(monkeypatch):
1919
assert client.api_key == "no-token"
2020

2121

22+
def test_open_ai_client_with_custom_params(monkeypatch):
23+
from databricks.sdk import WorkspaceClient
24+
25+
monkeypatch.setenv("DATABRICKS_HOST", "test_host")
26+
monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
27+
w = WorkspaceClient(config=Config())
28+
29+
# Test with timeout and max_retries parameters
30+
client = w.serving_endpoints.get_open_ai_client(timeout=30.0, max_retries=3)
31+
32+
assert client.base_url == "https://test_host/serving-endpoints/"
33+
assert client.api_key == "no-token"
34+
assert client.timeout == 30.0
35+
assert client.max_retries == 3
36+
37+
38+
def test_open_ai_client_with_additional_kwargs(monkeypatch):
39+
from databricks.sdk import WorkspaceClient
40+
41+
monkeypatch.setenv("DATABRICKS_HOST", "test_host")
42+
monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
43+
w = WorkspaceClient(config=Config())
44+
45+
# Test with additional kwargs that OpenAI client might accept
46+
client = w.serving_endpoints.get_open_ai_client(
47+
timeout=60.0, max_retries=5, default_headers={"Custom-Header": "test-value"}
48+
)
49+
50+
assert client.base_url == "https://test_host/serving-endpoints/"
51+
assert client.api_key == "no-token"
52+
assert client.timeout == 60.0
53+
assert client.max_retries == 5
54+
assert "Custom-Header" in client.default_headers
55+
assert client.default_headers["Custom-Header"] == "test-value"
56+
57+
58+
def test_open_ai_client_prevents_reserved_param_override(monkeypatch):
59+
from databricks.sdk import WorkspaceClient
60+
61+
monkeypatch.setenv("DATABRICKS_HOST", "test_host")
62+
monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
63+
w = WorkspaceClient(config=Config())
64+
65+
# Test that trying to override base_url raises an error
66+
with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: base_url"):
67+
w.serving_endpoints.get_open_ai_client(base_url="https://custom-host")
68+
69+
# Test that trying to override api_key raises an error
70+
with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: api_key"):
71+
w.serving_endpoints.get_open_ai_client(api_key="custom-key")
72+
73+
# Test that trying to override http_client raises an error
74+
with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: http_client"):
75+
w.serving_endpoints.get_open_ai_client(http_client=None)
76+
77+
# Test that trying to override multiple reserved params shows all of them
78+
with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: api_key, base_url"):
79+
w.serving_endpoints.get_open_ai_client(base_url="https://custom-host", api_key="custom-key")
80+
81+
2282
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python > 3.7")
2383
def test_langchain_open_ai_client(monkeypatch):
2484
from databricks.sdk import WorkspaceClient

0 commit comments

Comments
 (0)