Skip to content

Commit 9391bec

Browse files
authored
Merge branch 'main' into main
2 parents 73c08ed + 1c75815 commit 9391bec

File tree

6 files changed

+104
-4
lines changed

6 files changed

+104
-4
lines changed

.codegen/__init__.py.tmpl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ from databricks.sdk.credentials_provider import CredentialsStrategy
55
from databricks.sdk.mixins.files import DbfsExt
66
from databricks.sdk.mixins.compute import ClustersExt
77
from databricks.sdk.mixins.workspace import WorkspaceExt
8+
from databricks.sdk.mixins.open_ai_client import ServingEndpointsExt
89
{{- range .Services}}
910
from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}}
1011
from databricks.sdk.service.provisioning import Workspace
@@ -17,7 +18,7 @@ from typing import Optional
1718
"google_credentials" "google_service_account" }}
1819

1920
{{- define "api" -}}
20-
{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" -}}
21+
{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" "ServingEndpointsExt" "ServingEndpointsApi" -}}
2122
{{- $genApi := concat .PascalName "API" -}}
2223
{{- getOrDefault $mixins $genApi $genApi -}}
2324
{{- end -}}

NOTICE

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,22 @@ googleapis/google-auth-library-python - https://github.com/googleapis/google-aut
1212
Copyright google-auth-library-python authors
1313
License - https://github.com/googleapis/google-auth-library-python/blob/main/LICENSE
1414

15+
openai/openai-python - https://github.com/openai/openai-python
16+
Copyright 2024 OpenAI
17+
License - https://github.com/openai/openai-python/blob/main/LICENSE
18+
1519
This software contains code from the following open source projects, licensed under the BSD (3-clause) license.
1620

1721
x/oauth2 - https://cs.opensource.google/go/x/oauth2/+/master:oauth2.go
1822
Copyright 2014 The Go Authors. All rights reserved.
1923
License - https://cs.opensource.google/go/x/oauth2/+/master:LICENSE
24+
25+
encode/httpx - https://github.com/encode/httpx
26+
Copyright 2019, Encode OSS Ltd
27+
License - https://github.com/encode/httpx/blob/master/LICENSE.md
28+
29+
This software contains code from the following open source projects, licensed under the MIT license:
30+
31+
langchain-ai/langchain - https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai
32+
Copyright 2023 LangChain, Inc.
33+
License - https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai/LICENSE

databricks/sdk/__init__.py

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from databricks.sdk.service.serving import ServingEndpointsAPI
2+
3+
4+
class ServingEndpointsExt(ServingEndpointsAPI):
5+
6+
# Using the HTTP Client to pass in the databricks authorization
7+
# This method will be called on every invocation, so when using with model serving will always get the refreshed token
8+
def _get_authorized_http_client(self):
9+
import httpx
10+
11+
class BearerAuth(httpx.Auth):
12+
13+
def __init__(self, get_headers_func):
14+
self.get_headers_func = get_headers_func
15+
16+
def auth_flow(self, request: httpx.Request) -> httpx.Request:
17+
auth_headers = self.get_headers_func()
18+
request.headers["Authorization"] = auth_headers["Authorization"]
19+
yield request
20+
21+
databricks_token_auth = BearerAuth(self._api._cfg.authenticate)
22+
23+
# Create an HTTP client with Bearer Token authentication
24+
http_client = httpx.Client(auth=databricks_token_auth)
25+
return http_client
26+
27+
def get_open_ai_client(self):
28+
try:
29+
from openai import OpenAI
30+
except Exception:
31+
raise ImportError(
32+
"Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]`"
33+
)
34+
35+
return OpenAI(
36+
base_url=self._api._cfg.host + "/serving-endpoints",
37+
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used
38+
http_client=self._get_authorized_http_client())
39+
40+
def get_langchain_chat_open_ai_client(self, model):
41+
try:
42+
from langchain_openai import ChatOpenAI
43+
except Exception:
44+
raise ImportError(
45+
"Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]` and ensure you are using python>3.7"
46+
)
47+
48+
return ChatOpenAI(
49+
model=model,
50+
openai_api_base=self._api._cfg.host + "/serving-endpoints",
51+
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used
52+
http_client=self._get_authorized_http_client())

setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock",
1818
"yapf", "pycodestyle", "autoflake", "isort", "wheel",
1919
"ipython", "ipywidgets", "requests-mock", "pyfakefs",
20-
"databricks-connect", "pytest-rerunfailures"],
21-
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"]},
20+
"databricks-connect", "pytest-rerunfailures", "openai",
21+
'langchain-openai; python_version > "3.7"', "httpx"],
22+
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"],
23+
"openai": ["openai", 'langchain-openai; python_version > "3.7"', "httpx"]},
2224
author="Serge Smertin",
2325
author_email="[email protected]",
2426
description="Databricks SDK for Python (Beta)",

tests/test_open_ai_mixin.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import sys
2+
3+
import pytest
4+
5+
from databricks.sdk.core import Config
6+
7+
8+
def test_open_ai_client(monkeypatch):
9+
from databricks.sdk import WorkspaceClient
10+
11+
monkeypatch.setenv('DATABRICKS_HOST', 'test_host')
12+
monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token')
13+
w = WorkspaceClient(config=Config())
14+
client = w.serving_endpoints.get_open_ai_client()
15+
16+
assert client.base_url == "https://test_host/serving-endpoints/"
17+
assert client.api_key == "no-token"
18+
19+
20+
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python > 3.7")
21+
def test_langchain_open_ai_client(monkeypatch):
22+
from databricks.sdk import WorkspaceClient
23+
24+
monkeypatch.setenv('DATABRICKS_HOST', 'test_host')
25+
monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token')
26+
w = WorkspaceClient(config=Config())
27+
client = w.serving_endpoints.get_langchain_chat_open_ai_client("databricks-meta-llama-3-1-70b-instruct")
28+
29+
assert client.openai_api_base == "https://test_host/serving-endpoints"
30+
assert client.model_name == "databricks-meta-llama-3-1-70b-instruct"

0 commit comments

Comments
 (0)