|
| 1 | +import httpx |
| 2 | + |
1 | 3 | from databricks.sdk.service.serving import ServingEndpointsAPI |
2 | 4 |
|
3 | 5 |
|
4 | 6 | class ServingEndpointsExt(ServingEndpointsAPI): |
5 | 7 |
|
6 | | - def get_open_ai_client(self): |
7 | | - auth_headers = self._api._cfg.authenticate() |
| 8 | + # Using the HTTP Client to pass in the databricks authorization |
| 9 | + # This method will be called on every invocation, so when using with model serving will always get the refreshed token |
| 10 | + def _get_authorized_http_client(self): |
8 | 11 |
|
9 | | - try: |
10 | | - token = auth_headers["Authorization"][len("Bearer "):] |
11 | | - except Exception: |
12 | | - raise ValueError("Unable to extract authorization token for OpenAI Client") |
| 12 | + class BearerAuth(httpx.Auth): |
| 13 | + |
| 14 | + def __init__(self, get_headers_func): |
| 15 | + self.get_headers_func = get_headers_func |
| 16 | + |
| 17 | + def auth_flow(self, request: httpx.Request) -> httpx.Request: |
| 18 | + auth_headers = self.get_headers_func() |
| 19 | + request.headers["Authorization"] = auth_headers["Authorization"] |
| 20 | + yield request |
| 21 | + |
| 22 | + databricks_token_auth = BearerAuth(self._api._cfg.authenticate) |
13 | 23 |
|
| 24 | + # Create an HTTP client with Bearer Token authentication |
| 25 | + http_client = httpx.Client(auth=databricks_token_auth) |
| 26 | + return http_client |
| 27 | + |
| 28 | + def get_open_ai_client(self): |
14 | 29 | try: |
15 | 30 | from openai import OpenAI |
16 | 31 | except Exception: |
17 | 32 | raise ImportError( |
18 | 33 | "Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]`" |
19 | 34 | ) |
20 | 35 |
|
21 | | - return OpenAI(base_url=self._api._cfg.host + "/serving-endpoints", api_key=token) |
| 36 | + return OpenAI( |
| 37 | + base_url=self._api._cfg.host + "/serving-endpoints", |
| 38 | + api_key="no-token", # Passing in a placeholder to pass validations, this will not be used |
| 39 | + http_client=self._get_authorized_http_client()) |
22 | 40 |
|
23 | 41 | def get_langchain_chat_open_ai_client(self, model): |
24 | | - auth_headers = self._api._cfg.authenticate() |
25 | | - |
26 | 42 | try: |
27 | 43 | from langchain_openai import ChatOpenAI |
28 | 44 | except Exception: |
29 | 45 | raise ImportError( |
30 | 46 | "Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]` and ensure you are using python>3.7" |
31 | 47 | ) |
32 | 48 |
|
33 | | - try: |
34 | | - token = auth_headers["Authorization"][len("Bearer "):] |
35 | | - except Exception: |
36 | | - raise ValueError("Unable to extract authorization token for Langchain OpenAI Client") |
37 | | - |
38 | | - return ChatOpenAI(model=model, |
39 | | - openai_api_base=self._api._cfg.host + "/serving-endpoints", |
40 | | - openai_api_key=token) |
| 49 | + return ChatOpenAI( |
| 50 | + model=model, |
| 51 | + openai_api_base=self._api._cfg.host + "/serving-endpoints", |
| 52 | + api_key="no-token", # Passing in a placeholder to pass validations, this will not be used |
| 53 | + http_client=self._get_authorized_http_client()) |
0 commit comments