Skip to content

Commit 1c75815

Browse files
authored
[Feature] Open AI Client Mixin (#779)
## Changes Add an OpenAI client mixin to the Serving Endpoints API. The OpenAI client requires a token to authenticate, so we are moving the creation of the OpenAI client into the Databricks SDK so that users can easily use it in both notebook and model serving environments. ## Tests Dogfood Test: https://e2-dogfood.staging.cloud.databricks.com/editor/notebooks/2337940012762945?o=6051921418418893 - [x] `make test` run locally - [x] `make fmt` applied - [ ] relevant integration tests applied --------- Signed-off-by: aravind-segu <[email protected]>
1 parent a3794b1 commit 1c75815

File tree

6 files changed

+104
-4
lines changed

6 files changed

+104
-4
lines changed

.codegen/__init__.py.tmpl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ from databricks.sdk.credentials_provider import CredentialsStrategy
55
from databricks.sdk.mixins.files import DbfsExt
66
from databricks.sdk.mixins.compute import ClustersExt
77
from databricks.sdk.mixins.workspace import WorkspaceExt
8+
from databricks.sdk.mixins.open_ai_client import ServingEndpointsExt
89
{{- range .Services}}
910
from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}}
1011
from databricks.sdk.service.provisioning import Workspace
@@ -17,7 +18,7 @@ from typing import Optional
1718
"google_credentials" "google_service_account" }}
1819

1920
{{- define "api" -}}
20-
{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" -}}
21+
{{- /* Maps a generated "<Service>API" class name (the $genApi lookup key below) to the hand-written mixin class that replaces it. */ -}}
{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" "ServingEndpointsAPI" "ServingEndpointsExt" -}}
2122
{{- $genApi := concat .PascalName "API" -}}
2223
{{- getOrDefault $mixins $genApi $genApi -}}
2324
{{- end -}}

NOTICE

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,22 @@ googleapis/google-auth-library-python - https://github.com/googleapis/google-aut
1212
Copyright google-auth-library-python authors
1313
License - https://github.com/googleapis/google-auth-library-python/blob/main/LICENSE
1414

15+
openai/openai-python - https://github.com/openai/openai-python
16+
Copyright 2024 OpenAI
17+
License - https://github.com/openai/openai-python/blob/main/LICENSE
18+
1519
This software contains code from the following open source projects, licensed under the BSD (3-clause) license.
1620

1721
x/oauth2 - https://cs.opensource.google/go/x/oauth2/+/master:oauth2.go
1822
Copyright 2014 The Go Authors. All rights reserved.
1923
License - https://cs.opensource.google/go/x/oauth2/+/master:LICENSE
24+
25+
encode/httpx - https://github.com/encode/httpx
26+
Copyright 2019, Encode OSS Ltd
27+
License - https://github.com/encode/httpx/blob/master/LICENSE.md
28+
29+
This software contains code from the following open source projects, licensed under the MIT license:
30+
31+
langchain-ai/langchain - https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai
32+
Copyright 2023 LangChain, Inc.
33+
License - https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai/LICENSE

databricks/sdk/__init__.py

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from databricks.sdk.service.serving import ServingEndpointsAPI
2+
3+
4+
class ServingEndpointsExt(ServingEndpointsAPI):
    """Serving Endpoints API extended with factories for OpenAI-compatible clients.

    The clients returned here target ``<workspace host>/serving-endpoints`` and
    authenticate through the SDK configuration instead of an OpenAI API key.
    """

    def _get_authorized_http_client(self):
        """Build an ``httpx.Client`` that sets the Databricks ``Authorization`` header.

        A new client is created on every call.

        Returns:
            An ``httpx.Client`` whose auth hook calls ``self._api._cfg.authenticate``
            for each outgoing request.
        """
        # Imported lazily: httpx is only installed with the optional "openai" extra.
        from typing import Generator

        import httpx

        class BearerAuth(httpx.Auth):
            """httpx auth hook that copies ``Authorization`` from a header-producing callable."""

            def __init__(self, get_headers_func):
                self.get_headers_func = get_headers_func

            def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
                # Headers are requested for every request rather than cached, so a
                # token refreshed by the SDK config (e.g. in model serving) is used.
                auth_headers = self.get_headers_func()
                request.headers["Authorization"] = auth_headers["Authorization"]
                yield request

        databricks_token_auth = BearerAuth(self._api._cfg.authenticate)

        # Create an HTTP client with Bearer Token authentication
        return httpx.Client(auth=databricks_token_auth)

    def get_open_ai_client(self):
        """Return an ``openai.OpenAI`` client pointed at this workspace's serving endpoints.

        Returns:
            An ``OpenAI`` instance using the authorized HTTP client for credentials.

        Raises:
            ImportError: if ``openai`` cannot be imported.
        """
        try:
            from openai import OpenAI
        except Exception as err:
            # Any import-time failure is reported as ImportError (unchanged contract);
            # the original error is chained so the real cause stays visible.
            raise ImportError(
                "Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`"
            ) from err

        return OpenAI(
            base_url=self._api._cfg.host + "/serving-endpoints",
            api_key="no-token",  # Passing in a placeholder to pass validations, this will not be used
            http_client=self._get_authorized_http_client())

    def get_langchain_chat_open_ai_client(self, model):
        """Return a LangChain ``ChatOpenAI`` client pointed at this workspace's serving endpoints.

        Args:
            model: value forwarded to ``ChatOpenAI`` as its ``model`` argument.

        Returns:
            A ``ChatOpenAI`` instance using the authorized HTTP client for credentials.

        Raises:
            ImportError: if ``langchain_openai`` cannot be imported.
        """
        try:
            from langchain_openai import ChatOpenAI
        except Exception as err:
            # Same policy as get_open_ai_client: surface ImportError, keep the cause.
            raise ImportError(
                "Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]` and ensure you are using python>3.7"
            ) from err

        return ChatOpenAI(
            model=model,
            openai_api_base=self._api._cfg.host + "/serving-endpoints",
            api_key="no-token",  # Passing in a placeholder to pass validations, this will not be used
            http_client=self._get_authorized_http_client())

setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock",
1818
"yapf", "pycodestyle", "autoflake", "isort", "wheel",
1919
"ipython", "ipywidgets", "requests-mock", "pyfakefs",
20-
"databricks-connect", "pytest-rerunfailures"],
21-
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"]},
20+
"databricks-connect", "pytest-rerunfailures", "openai",
21+
'langchain-openai; python_version > "3.7"', "httpx"],
22+
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"],
23+
"openai": ["openai", 'langchain-openai; python_version > "3.7"', "httpx"]},
2224
author="Serge Smertin",
2325
author_email="[email protected]",
2426
description="Databricks SDK for Python (Beta)",

tests/test_open_ai_mixin.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import sys
2+
3+
import pytest
4+
5+
from databricks.sdk.core import Config
6+
7+
8+
def test_open_ai_client(monkeypatch):
    """The OpenAI client targets the workspace serving endpoints with a placeholder key."""
    from databricks.sdk import WorkspaceClient

    # Host and token come from the environment so Config() resolves them itself.
    for env_name, env_value in (('DATABRICKS_HOST', 'test_host'), ('DATABRICKS_TOKEN', 'test_token')):
        monkeypatch.setenv(env_name, env_value)

    workspace = WorkspaceClient(config=Config())
    open_ai = workspace.serving_endpoints.get_open_ai_client()

    assert open_ai.base_url == "https://test_host/serving-endpoints/"
    assert open_ai.api_key == "no-token"
18+
19+
20+
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python > 3.7")
def test_langchain_open_ai_client(monkeypatch):
    """The LangChain chat client targets the workspace serving endpoints for the given model."""
    from databricks.sdk import WorkspaceClient

    # Host and token come from the environment so Config() resolves them itself.
    for env_name, env_value in (('DATABRICKS_HOST', 'test_host'), ('DATABRICKS_TOKEN', 'test_token')):
        monkeypatch.setenv(env_name, env_value)

    endpoint_model = "databricks-meta-llama-3-1-70b-instruct"
    workspace = WorkspaceClient(config=Config())
    chat_client = workspace.serving_endpoints.get_langchain_chat_open_ai_client(endpoint_model)

    assert chat_client.openai_api_base == "https://test_host/serving-endpoints"
    assert chat_client.model_name == "databricks-meta-llama-3-1-70b-instruct"

0 commit comments

Comments
 (0)