Skip to content

Commit 46c310b

Browse files
committed
Add open ai client mixin
1 parent 79b096f commit 46c310b

File tree

4 files changed

+33
-2
lines changed

4 files changed

+33
-2
lines changed

.codegen/__init__.py.tmpl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ from databricks.sdk.credentials_provider import CredentialsStrategy
55
from databricks.sdk.mixins.files import DbfsExt
66
from databricks.sdk.mixins.compute import ClustersExt
77
from databricks.sdk.mixins.workspace import WorkspaceExt
8+
from databricks.sdk.mixins.open_ai_client import ServingEndpointsExt
89
{{- range .Services}}
910
from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}}
1011
from databricks.sdk.service.provisioning import Workspace
@@ -17,7 +18,7 @@ from typing import Optional
1718
"google_credentials" "google_service_account" }}
1819

1920
{{- define "api" -}}
20-
{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" -}}
21+
{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" "ServingEndpointsExt" "ServingEndpointsApi" -}}
2122
{{- $genApi := concat .PascalName "API" -}}
2223
{{- getOrDefault $mixins $genApi $genApi -}}
2324
{{- end -}}

databricks/sdk/__init__.py

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from databricks.sdk.service.serving import ServingEndpointsAPI
2+
3+
class ServingEndpointsExt(ServingEndpointsAPI):
4+
def get_open_api_client(self):
5+
auth_headers = self._api._cfg.authenticate()
6+
7+
try:
8+
token = auth_headers["Authorization"][len("Bearer "):]
9+
except Exception:
10+
raise ValueError("Unable to extract authorization token for OpenAI Client")
11+
12+
from openai import OpenAI
13+
return OpenAI(
14+
base_url=self._api._cfg.host + "/serving-endpoints",
15+
api_key=token
16+
)

tests/test_open_ai_mixin.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pytest
2+
from databricks.sdk.core import Config
3+
4+
def test_open_ai_client(monkeypatch):
5+
from databricks.sdk import WorkspaceClient
6+
7+
monkeypatch.setenv('DATABRICKS_HOST', 'test_host')
8+
monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token')
9+
w = WorkspaceClient(config=Config())
10+
client = w.serving_endpoints.get_open_api_client()
11+
12+
assert client.base_url == "https://test_host/serving-endpoints/"
13+
assert client.api_key == "test_token"

0 commit comments

Comments
 (0)