Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions databricks/sdk/mixins/open_ai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ def get_open_ai_client(self):
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used
http_client=self._get_authorized_http_client())

def get_async_open_ai_client(self):
try:
from openai import AsyncOpenAI
except Exception:
raise ImportError(
"Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`"
)

return AsyncOpenAI(
base_url=self._api._cfg.host + "/serving-endpoints",
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used
http_client=self._get_authorized_http_client())

def get_langchain_chat_open_ai_client(self, model):
try:
from langchain_openai import ChatOpenAI
Expand Down
11 changes: 11 additions & 0 deletions tests/test_open_ai_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ def test_open_ai_client(monkeypatch):
assert client.base_url == "https://test_host/serving-endpoints/"
assert client.api_key == "no-token"

def test_async_open_ai_client(monkeypatch):
from databricks.sdk import WorkspaceClient

monkeypatch.setenv('DATABRICKS_HOST', 'test_host')
monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token')
w = WorkspaceClient(config=Config())
client = w.serving_endpoints.get_async_open_ai_client()

assert client.base_url == "https://test_host/serving-endpoints/"
assert client.api_key == "no-token"


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python > 3.7")
def test_langchain_open_ai_client(monkeypatch):
Expand Down
Loading