From 96a9c90e5ae10c9a6efad8cb484c19759035e497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michal=20Bel=C3=A1k?= Date: Mon, 6 Jan 2025 16:55:04 +0100 Subject: [PATCH] add `get_async_open_ai_client` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Michal Belák --- databricks/sdk/mixins/open_ai_client.py | 13 +++++++++++++ tests/test_open_ai_mixin.py | 11 +++++++++++ 2 files changed, 24 insertions(+) diff --git a/databricks/sdk/mixins/open_ai_client.py b/databricks/sdk/mixins/open_ai_client.py index a86827128..ac12650ae 100644 --- a/databricks/sdk/mixins/open_ai_client.py +++ b/databricks/sdk/mixins/open_ai_client.py @@ -37,6 +37,19 @@ def get_open_ai_client(self): api_key="no-token", # Passing in a placeholder to pass validations, this will not be used http_client=self._get_authorized_http_client()) + def get_async_open_ai_client(self): + try: + from openai import AsyncOpenAI + except Exception: + raise ImportError( + "Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`" + ) + + return AsyncOpenAI( + base_url=self._api._cfg.host + "/serving-endpoints", + api_key="no-token", # Passing in a placeholder to pass validations, this will not be used + http_client=self._get_authorized_http_client()) + def get_langchain_chat_open_ai_client(self, model): try: from langchain_openai import ChatOpenAI diff --git a/tests/test_open_ai_mixin.py b/tests/test_open_ai_mixin.py index 1858c66cb..c7b406680 100644 --- a/tests/test_open_ai_mixin.py +++ b/tests/test_open_ai_mixin.py @@ -16,6 +16,17 @@ def test_open_ai_client(monkeypatch): assert client.base_url == "https://test_host/serving-endpoints/" assert client.api_key == "no-token" +def test_async_open_ai_client(monkeypatch): + from databricks.sdk import WorkspaceClient + + monkeypatch.setenv('DATABRICKS_HOST', 'test_host') + monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token') + w = WorkspaceClient(config=Config()) + client = w.serving_endpoints.get_async_open_ai_client() + + assert client.base_url == "https://test_host/serving-endpoints/" + assert client.api_key == "no-token" + @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python > 3.7") def test_langchain_open_ai_client(monkeypatch):