Skip to content

Commit 9250922

Browse files
committed
Add Langchain Open AI Client
1 parent cbd4e30 commit 9250922

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

databricks/sdk/mixins/open_ai_client.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
class ServingEndpointsExt(ServingEndpointsAPI):
55

6-
def get_open_api_client(self):
6+
def get_open_ai_client(self):
77
auth_headers = self._api._cfg.authenticate()
88

99
try:
@@ -13,3 +13,16 @@ def get_open_api_client(self):
1313

1414
from openai import OpenAI
1515
return OpenAI(base_url=self._api._cfg.host + "/serving-endpoints", api_key=token)
16+
17+
def get_langchain_chat_open_ai_client(self, model):
18+
auth_headers = self._api._cfg.authenticate()
19+
20+
try:
21+
token = auth_headers["Authorization"][len("Bearer "):]
22+
except Exception:
23+
raise ValueError("Unable to extract authorization token for Langchain OpenAI Client")
24+
25+
from langchain_openai import ChatOpenAI
26+
return ChatOpenAI(model=model,
27+
openai_api_base=self._api._cfg.host + "/serving-endpoints",
28+
openai_api_key=token)

setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock",
1818
"yapf", "pycodestyle", "autoflake", "isort", "wheel",
1919
"ipython", "ipywidgets", "requests-mock", "pyfakefs",
20-
"databricks-connect", "pytest-rerunfailures", "openai"],
20+
"databricks-connect", "pytest-rerunfailures", "openai",
21+
"langchain-openai"],
2122
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"],
22-
"openai": ["openai"]},
23+
"openai": ["openai", "langchain-openai"]},
2324
author="Serge Smertin",
2425
author_email="[email protected]",
2526
description="Databricks SDK for Python (Beta)",

tests/test_open_ai_mixin.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,19 @@ def test_open_ai_client(monkeypatch):
77
monkeypatch.setenv('DATABRICKS_HOST', 'test_host')
88
monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token')
99
w = WorkspaceClient(config=Config())
10-
client = w.serving_endpoints.get_open_api_client()
10+
client = w.serving_endpoints.get_open_ai_client()
1111

1212
assert client.base_url == "https://test_host/serving-endpoints/"
1313
assert client.api_key == "test_token"
14+
15+
16+
def test_langchain_open_ai_client(monkeypatch):
17+
from databricks.sdk import WorkspaceClient
18+
19+
monkeypatch.setenv('DATABRICKS_HOST', 'test_host')
20+
monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token')
21+
w = WorkspaceClient(config=Config())
22+
client = w.serving_endpoints.get_langchain_chat_open_ai_client("databricks-meta-llama-3-1-70b-instruct")
23+
24+
assert client.openai_api_base == "https://test_host/serving-endpoints"
25+
assert client.model_name == "databricks-meta-llama-3-1-70b-instruct"

0 commit comments

Comments
 (0)