2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
@@ -4,6 +4,8 @@

### New Features and Improvements

* Added support for passing additional kwargs to `WorkspaceClient().serving_endpoints.get_open_ai_client()` ([#1025](https://github.com/databricks/databricks-sdk-py/pull/1025)). Users can now pass standard OpenAI client parameters like `timeout` and `max_retries` when creating an OpenAI client for Databricks Model Serving.

### Bug Fixes

### Documentation
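The changelog entry above summarizes the new pass-through behavior. A minimal usage sketch, assuming a workspace configured via `DATABRICKS_HOST`/`DATABRICKS_TOKEN` and the `openai` extra installed; the endpoint name below is a hypothetical placeholder:

```python
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

# Extra kwargs are forwarded to the OpenAI constructor by get_open_ai_client().
client = w.serving_endpoints.get_open_ai_client(timeout=30.0, max_retries=3)

# The returned client targets <workspace-host>/serving-endpoints with Databricks auth.
response = client.chat.completions.create(
    model="my-serving-endpoint",  # hypothetical endpoint name
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)
```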
61 changes: 55 additions & 6 deletions databricks/sdk/mixins/open_ai_client.py
@@ -31,19 +31,68 @@ def auth_flow(self, request: httpx.Request) -> httpx.Request:
http_client = httpx.Client(auth=databricks_token_auth)
return http_client

def get_open_ai_client(self):
def get_open_ai_client(self, **kwargs):
"""Create an OpenAI client configured for Databricks Model Serving.

Returns an OpenAI client instance that is pre-configured to send requests to
Databricks Model Serving endpoints. The client uses Databricks authentication
to query endpoints within the workspace associated with the current WorkspaceClient
instance.

Args:
**kwargs: Additional parameters to pass to the OpenAI client constructor.
Common parameters include:
- timeout (float): Request timeout in seconds (e.g., 30.0)
- max_retries (int): Maximum number of retries for failed requests (e.g., 3)
- default_headers (dict): Additional headers to include with requests
- default_query (dict): Additional query parameters to include with requests

Any parameter accepted by the OpenAI client constructor can be passed here,
except for `base_url`, `api_key`, and `http_client`, which are reserved for
the Databricks integration.

Returns:
OpenAI: An OpenAI client instance configured for Databricks Model Serving.

Raises:
ImportError: If the OpenAI library is not installed.
ValueError: If any reserved Databricks parameters are provided in kwargs.

Example:
>>> client = workspace_client.serving_endpoints.get_open_ai_client()
>>> # With custom timeout and retries
>>> client = workspace_client.serving_endpoints.get_open_ai_client(
... timeout=30.0,
... max_retries=5
... )
"""
try:
from openai import OpenAI
except Exception:
raise ImportError(
"Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`"
)

return OpenAI(
base_url=self._api._cfg.host + "/serving-endpoints",
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used
http_client=self._get_authorized_http_client(),
)
# Check for reserved parameters that should not be overridden
reserved_params = {"base_url", "api_key", "http_client"}
conflicting_params = reserved_params.intersection(kwargs.keys())
if conflicting_params:
raise ValueError(
f"Cannot override reserved Databricks parameters: {', '.join(sorted(conflicting_params))}. "
f"These parameters are automatically configured for Databricks Model Serving."
)

# Default parameters that are required for Databricks integration
client_params = {
"base_url": self._api._cfg.host + "/serving-endpoints",
"api_key": "no-token", # Passing in a placeholder to pass validations, this will not be used
"http_client": self._get_authorized_http_client(),
}

# Update with any additional parameters passed by the user
client_params.update(kwargs)
[Inline review thread on the line above]

Reviewer: do we want to make it so you can't override the default client_params since it'll break compatibility w/ databricks model serving? we could also document what the defaults are

Contributor Author: Yeah, seems safe to block overriding those & throw to start

return OpenAI(**client_params)

def get_langchain_chat_open_ai_client(self, model):
try:
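As agreed in the review thread above, the reserved parameters are rejected rather than silently overridden. A short illustrative sketch of how the guard surfaces to callers (the setup is hypothetical; the error text matches the diff):

```python
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

try:
    # base_url is one of the reserved parameters (base_url, api_key, http_client),
    # so this call should raise instead of reconfiguring the endpoint URL.
    w.serving_endpoints.get_open_ai_client(base_url="https://example.com")
except ValueError as err:
    print(err)
    # Cannot override reserved Databricks parameters: base_url. These parameters
    # are automatically configured for Databricks Model Serving.
```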
60 changes: 60 additions & 0 deletions tests/test_open_ai_mixin.py
@@ -19,6 +19,66 @@ def test_open_ai_client(monkeypatch):
assert client.api_key == "no-token"


def test_open_ai_client_with_custom_params(monkeypatch):
from databricks.sdk import WorkspaceClient

monkeypatch.setenv("DATABRICKS_HOST", "test_host")
monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
w = WorkspaceClient(config=Config())

# Test with timeout and max_retries parameters
client = w.serving_endpoints.get_open_ai_client(timeout=30.0, max_retries=3)

assert client.base_url == "https://test_host/serving-endpoints/"
assert client.api_key == "no-token"
assert client.timeout == 30.0
assert client.max_retries == 3


def test_open_ai_client_with_additional_kwargs(monkeypatch):
from databricks.sdk import WorkspaceClient

monkeypatch.setenv("DATABRICKS_HOST", "test_host")
monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
w = WorkspaceClient(config=Config())

# Test with additional kwargs that OpenAI client might accept
client = w.serving_endpoints.get_open_ai_client(
timeout=60.0, max_retries=5, default_headers={"Custom-Header": "test-value"}
)

assert client.base_url == "https://test_host/serving-endpoints/"
assert client.api_key == "no-token"
assert client.timeout == 60.0
assert client.max_retries == 5
assert "Custom-Header" in client.default_headers
assert client.default_headers["Custom-Header"] == "test-value"


def test_open_ai_client_prevents_reserved_param_override(monkeypatch):
from databricks.sdk import WorkspaceClient

monkeypatch.setenv("DATABRICKS_HOST", "test_host")
monkeypatch.setenv("DATABRICKS_TOKEN", "test_token")
w = WorkspaceClient(config=Config())

# Test that trying to override base_url raises an error
with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: base_url"):
w.serving_endpoints.get_open_ai_client(base_url="https://custom-host")

# Test that trying to override api_key raises an error
with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: api_key"):
w.serving_endpoints.get_open_ai_client(api_key="custom-key")

# Test that trying to override http_client raises an error
with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: http_client"):
w.serving_endpoints.get_open_ai_client(http_client=None)

# Test that trying to override multiple reserved params shows all of them
with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: api_key, base_url"):
w.serving_endpoints.get_open_ai_client(base_url="https://custom-host", api_key="custom-key")


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python > 3.7")
def test_langchain_open_ai_client(monkeypatch):
from databricks.sdk import WorkspaceClient
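The new tests can be run in isolation with `pytest tests/test_open_ai_mixin.py`, assuming the repository's usual development dependencies (including the `openai` extra) are installed.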