diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 2fdb85ae0..0260621f6 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,8 @@ ### New Features and Improvements +* Added support for passing additional kwargs to `WorkspaceClient().serving_endpoints.get_open_ai_client()` ([#1025](https://github.com/databricks/databricks-sdk-py/pull/1025)). Users can now pass standard OpenAI client parameters like `timeout` and `max_retries` when creating an OpenAI client for Databricks Model Serving. + ### Bug Fixes ### Documentation diff --git a/databricks/sdk/mixins/open_ai_client.py b/databricks/sdk/mixins/open_ai_client.py index 7d16f519d..4ab08ee5a 100644 --- a/databricks/sdk/mixins/open_ai_client.py +++ b/databricks/sdk/mixins/open_ai_client.py @@ -31,7 +31,41 @@ def auth_flow(self, request: httpx.Request) -> httpx.Request: http_client = httpx.Client(auth=databricks_token_auth) return http_client - def get_open_ai_client(self): + def get_open_ai_client(self, **kwargs): + """Create an OpenAI client configured for Databricks Model Serving. + + Returns an OpenAI client instance that is pre-configured to send requests to + Databricks Model Serving endpoints. The client uses Databricks authentication + to query endpoints within the workspace associated with the current WorkspaceClient + instance. + + Args: + **kwargs: Additional parameters to pass to the OpenAI client constructor. + Common parameters include: + - timeout (float): Request timeout in seconds (e.g., 30.0) + - max_retries (int): Maximum number of retries for failed requests (e.g., 3) + - default_headers (dict): Additional headers to include with requests + - default_query (dict): Additional query parameters to include with requests + + Any parameter accepted by the OpenAI client constructor can be passed here, + except for the following parameters which are reserved for Databricks integration: + base_url, api_key, http_client + + Returns: + OpenAI: An OpenAI client instance configured for Databricks Model Serving. + + Raises: + ImportError: If the OpenAI library is not installed. + ValueError: If any reserved Databricks parameters are provided in kwargs. + + Example: + >>> client = workspace_client.serving_endpoints.get_open_ai_client() + >>> # With custom timeout and retries + >>> client = workspace_client.serving_endpoints.get_open_ai_client( + ... timeout=30.0, + ... max_retries=5 + ... ) + """ try: from openai import OpenAI except Exception: @@ -39,11 +73,26 @@ def get_open_ai_client(self): "Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`" ) - return OpenAI( - base_url=self._api._cfg.host + "/serving-endpoints", - api_key="no-token", # Passing in a placeholder to pass validations, this will not be used - http_client=self._get_authorized_http_client(), - ) + # Check for reserved parameters that should not be overridden + reserved_params = {"base_url", "api_key", "http_client"} + conflicting_params = reserved_params.intersection(kwargs.keys()) + if conflicting_params: + raise ValueError( + f"Cannot override reserved Databricks parameters: {', '.join(sorted(conflicting_params))}. " + f"These parameters are automatically configured for Databricks Model Serving." + ) + + # Default parameters that are required for Databricks integration + client_params = { + "base_url": self._api._cfg.host + "/serving-endpoints", + "api_key": "no-token", # Passing in a placeholder to pass validations, this will not be used + "http_client": self._get_authorized_http_client(), + } + + # Update with any additional parameters passed by the user + client_params.update(kwargs) + + return OpenAI(**client_params) def get_langchain_chat_open_ai_client(self, model): try: diff --git a/tests/test_open_ai_mixin.py b/tests/test_open_ai_mixin.py index 5c17e48f7..dfc248d0a 100644 --- a/tests/test_open_ai_mixin.py +++ b/tests/test_open_ai_mixin.py @@ -19,6 +19,66 @@ def test_open_ai_client(monkeypatch): assert client.api_key == "no-token" +def test_open_ai_client_with_custom_params(monkeypatch): + from databricks.sdk import WorkspaceClient + + monkeypatch.setenv("DATABRICKS_HOST", "test_host") + monkeypatch.setenv("DATABRICKS_TOKEN", "test_token") + w = WorkspaceClient(config=Config()) + + # Test with timeout and max_retries parameters + client = w.serving_endpoints.get_open_ai_client(timeout=30.0, max_retries=3) + + assert client.base_url == "https://test_host/serving-endpoints/" + assert client.api_key == "no-token" + assert client.timeout == 30.0 + assert client.max_retries == 3 + + +def test_open_ai_client_with_additional_kwargs(monkeypatch): + from databricks.sdk import WorkspaceClient + + monkeypatch.setenv("DATABRICKS_HOST", "test_host") + monkeypatch.setenv("DATABRICKS_TOKEN", "test_token") + w = WorkspaceClient(config=Config()) + + # Test with additional kwargs that OpenAI client might accept + client = w.serving_endpoints.get_open_ai_client( + timeout=60.0, max_retries=5, default_headers={"Custom-Header": "test-value"} + ) + + assert client.base_url == "https://test_host/serving-endpoints/" + assert client.api_key == "no-token" + assert client.timeout == 60.0 + assert client.max_retries == 5 + assert "Custom-Header" in client.default_headers + assert client.default_headers["Custom-Header"] == "test-value" + + +def test_open_ai_client_prevents_reserved_param_override(monkeypatch): + from databricks.sdk import WorkspaceClient + + monkeypatch.setenv("DATABRICKS_HOST", "test_host") + monkeypatch.setenv("DATABRICKS_TOKEN", "test_token") + w = WorkspaceClient(config=Config()) + + # Test that trying to override base_url raises an error + with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: base_url"): + w.serving_endpoints.get_open_ai_client(base_url="https://custom-host") + + # Test that trying to override api_key raises an error + with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: api_key"): + w.serving_endpoints.get_open_ai_client(api_key="custom-key") + + # Test that trying to override http_client raises an error + with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: http_client"): + w.serving_endpoints.get_open_ai_client(http_client=None) + + # Test that trying to override multiple reserved params shows all of them + with pytest.raises(ValueError, match="Cannot override reserved Databricks parameters: api_key, base_url"): + w.serving_endpoints.get_open_ai_client(base_url="https://custom-host", api_key="custom-key") + + @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python > 3.7") def test_langchain_open_ai_client(monkeypatch): from databricks.sdk import WorkspaceClient