Skip to content

Commit 04ff85e

Browse files
jxnlcursoragent
andauthored
Fix the issue (#1914)
Co-authored-by: Cursor Agent <[email protected]>
1 parent 1480675 commit 04ff85e

File tree

3 files changed

+242
-7
lines changed

3 files changed

+242
-7
lines changed

docs/integrations/databricks.md

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,27 @@ description: Guide to using instructor with Databricks models
1212
First, install the required packages:
1313

1414
```bash
15-
pip install instructor
15+
uv pip install instructor openai
1616
```
1717

18-
You'll need a Databricks API key and workspace URL which you can set as environment variables:
18+
Set your Databricks workspace URL and token as environment variables:
1919

2020
```bash
21-
export DATABRICKS_API_KEY=your_api_key_here
22-
export DATABRICKS_HOST=your_workspace_url
21+
export DATABRICKS_TOKEN="your_personal_access_token"
22+
export DATABRICKS_HOST="https://your-workspace.cloud.databricks.com"
2323
```
2424

25+
`DATABRICKS_API_KEY` and `DATABRICKS_WORKSPACE_URL` are also supported if you prefer those names. The provider appends `/serving-endpoints` automatically, so the host only needs the base workspace URL.
26+
2527
## Basic Example
2628

2729
Here's how to extract structured data from Databricks models:
2830

2931
```python
30-
import os
3132
import instructor
32-
from openai import OpenAI
3333
from pydantic import BaseModel
3434

35-
# Initialize the client with Databricks base URL
35+
# Initialize the client; host and token are read from the environment
3636
client = instructor.from_provider(
3737
"databricks/dbrx-instruct",
3838
mode=instructor.Mode.TOOLS,
@@ -55,6 +55,8 @@ print(user)
5555
# Output: UserExtract(name='Jason', age=25)
5656
```
5757

58+
If you need to point at a different workspace or testing endpoint, pass `base_url="https://alt-workspace.cloud.databricks.com/serving-endpoints"`. The helper will use that value as-is without adding another suffix.
59+
5860
### Async Example
5961

6062
```python

instructor/auto_client.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
supported_providers = [
1818
"openai",
1919
"azure_openai",
20+
"databricks",
2021
"anthropic",
2122
"google",
2223
"generative-ai",
@@ -262,6 +263,93 @@ def from_provider(
262263
)
263264
raise
264265

266+
elif provider == "databricks":
267+
try:
268+
import os
269+
import openai
270+
from instructor import from_openai # type: ignore[attr-defined]
271+
272+
api_key = api_key or os.environ.get("DATABRICKS_TOKEN") or os.environ.get(
273+
"DATABRICKS_API_KEY"
274+
)
275+
if not api_key:
276+
from .core.exceptions import ConfigurationError
277+
278+
raise ConfigurationError(
279+
"DATABRICKS_TOKEN is not set. "
280+
"Set it with `export DATABRICKS_TOKEN=<your-token>` or `export DATABRICKS_API_KEY=<your-token>` "
281+
"or pass it as kwarg `api_key=<your-token>`."
282+
)
283+
284+
base_url = kwargs.pop("base_url", None)
285+
if base_url is None:
286+
base_url = (
287+
os.environ.get("DATABRICKS_BASE_URL")
288+
or os.environ.get("DATABRICKS_HOST")
289+
or os.environ.get("DATABRICKS_WORKSPACE_URL")
290+
)
291+
292+
if not base_url:
293+
from .core.exceptions import ConfigurationError
294+
295+
raise ConfigurationError(
296+
"DATABRICKS_HOST is not set. "
297+
"Set it with `export DATABRICKS_HOST=<your-workspace-url>` or `export DATABRICKS_WORKSPACE_URL=<your-workspace-url>` "
298+
"or pass `base_url=<your-workspace-url>`."
299+
)
300+
301+
base_url = str(base_url).rstrip("/")
302+
if not base_url.endswith("/serving-endpoints"):
303+
base_url = f"{base_url}/serving-endpoints"
304+
305+
openai_client_kwargs = {}
306+
for key in (
307+
"organization",
308+
"timeout",
309+
"max_retries",
310+
"default_headers",
311+
"http_client",
312+
"app_info",
313+
):
314+
if key in kwargs:
315+
openai_client_kwargs[key] = kwargs.pop(key)
316+
317+
client = (
318+
openai.AsyncOpenAI(
319+
api_key=api_key, base_url=base_url, **openai_client_kwargs
320+
)
321+
if async_client
322+
else openai.OpenAI(
323+
api_key=api_key, base_url=base_url, **openai_client_kwargs
324+
)
325+
)
326+
result = from_openai(
327+
client,
328+
model=model_name,
329+
mode=mode if mode else instructor.Mode.TOOLS,
330+
**kwargs,
331+
)
332+
logger.info(
333+
"Client initialized",
334+
extra={**provider_info, "status": "success"},
335+
)
336+
return result
337+
except ImportError:
338+
from .core.exceptions import ConfigurationError
339+
340+
raise ConfigurationError(
341+
"The openai package is required to use the Databricks provider. "
342+
"Install it with `pip install openai`."
343+
) from None
344+
except Exception as e:
345+
logger.error(
346+
"Error initializing %s client: %s",
347+
provider,
348+
e,
349+
exc_info=True,
350+
extra={**provider_info, "status": "error"},
351+
)
352+
raise
265353
elif provider == "anthropic":
266354
try:
267355
import anthropic

tests/test_auto_client.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,151 @@ def test_api_key_logging():
270270
)
271271

272272

273+
def test_databricks_provider_uses_environment_configuration():
274+
"""Ensure Databricks provider pulls host and token from the environment."""
275+
from unittest.mock import patch, MagicMock
276+
import os
277+
278+
with patch("openai.OpenAI") as mock_openai_class:
279+
mock_client = MagicMock()
280+
mock_openai_class.return_value = mock_client
281+
282+
with patch("instructor.from_openai") as mock_from_openai:
283+
mock_instructor = MagicMock()
284+
mock_from_openai.return_value = mock_instructor
285+
286+
with patch.dict(
287+
os.environ,
288+
{
289+
"DATABRICKS_HOST": "https://example.cloud.databricks.com",
290+
"DATABRICKS_TOKEN": "secret-token",
291+
},
292+
clear=True,
293+
):
294+
client = from_provider("databricks/dbrx-instruct")
295+
296+
mock_openai_class.assert_called_once()
297+
_, kwargs = mock_openai_class.call_args
298+
assert kwargs["api_key"] == "secret-token"
299+
assert (
300+
kwargs["base_url"]
301+
== "https://example.cloud.databricks.com/serving-endpoints"
302+
)
303+
mock_from_openai.assert_called_once()
304+
assert client is mock_instructor
305+
306+
307+
def test_databricks_provider_respects_custom_base_url():
308+
"""Ensure Databricks provider does not duplicate serving-endpoints suffix."""
309+
from unittest.mock import patch, MagicMock
310+
import os
311+
312+
with patch("openai.OpenAI") as mock_openai_class:
313+
mock_client = MagicMock()
314+
mock_openai_class.return_value = mock_client
315+
316+
with patch("instructor.from_openai") as mock_from_openai:
317+
mock_instructor = MagicMock()
318+
mock_from_openai.return_value = mock_instructor
319+
320+
with patch.dict(
321+
os.environ,
322+
{
323+
"DATABRICKS_TOKEN": "secret-token",
324+
},
325+
clear=True,
326+
):
327+
client = from_provider(
328+
"databricks/dbrx-instruct",
329+
base_url="https://example.cloud.databricks.com/serving-endpoints",
330+
)
331+
332+
_, kwargs = mock_openai_class.call_args
333+
assert (
334+
kwargs["base_url"]
335+
== "https://example.cloud.databricks.com/serving-endpoints"
336+
)
337+
mock_from_openai.assert_called_once()
338+
assert client is mock_instructor
339+
340+
341+
def test_databricks_provider_async_client():
342+
"""Ensure Databricks provider returns async client when requested."""
343+
from unittest.mock import patch, MagicMock
344+
import os
345+
346+
with patch("openai.AsyncOpenAI") as mock_async_openai_class:
347+
mock_client = MagicMock()
348+
mock_async_openai_class.return_value = mock_client
349+
350+
with patch("instructor.from_openai") as mock_from_openai:
351+
mock_instructor = MagicMock()
352+
mock_from_openai.return_value = mock_instructor
353+
354+
with patch.dict(
355+
os.environ,
356+
{
357+
"DATABRICKS_HOST": "https://example.cloud.databricks.com",
358+
"DATABRICKS_TOKEN": "secret-token",
359+
},
360+
clear=True,
361+
):
362+
client = from_provider(
363+
"databricks/dbrx-instruct", async_client=True
364+
)
365+
366+
mock_async_openai_class.assert_called_once()
367+
_, kwargs = mock_async_openai_class.call_args
368+
assert (
369+
kwargs["base_url"]
370+
== "https://example.cloud.databricks.com/serving-endpoints"
371+
)
372+
assert kwargs["api_key"] == "secret-token"
373+
mock_from_openai.assert_called_once()
374+
assert client is mock_instructor
375+
376+
377+
def test_databricks_provider_requires_token():
378+
"""Ensure Databricks provider raises when no token is available."""
379+
from instructor.core.exceptions import ConfigurationError
380+
from unittest.mock import patch, MagicMock
381+
import os
382+
383+
with patch("openai.OpenAI") as mock_openai_class:
384+
mock_openai_class.return_value = MagicMock()
385+
with patch("instructor.from_openai") as mock_from_openai:
386+
mock_from_openai.return_value = MagicMock()
387+
with patch.dict(
388+
os.environ,
389+
{
390+
"DATABRICKS_HOST": "https://example.cloud.databricks.com",
391+
},
392+
clear=True,
393+
):
394+
with pytest.raises(ConfigurationError):
395+
from_provider("databricks/dbrx-instruct")
396+
397+
398+
def test_databricks_provider_requires_host():
399+
"""Ensure Databricks provider raises when no host is available."""
400+
from instructor.core.exceptions import ConfigurationError
401+
from unittest.mock import patch, MagicMock
402+
import os
403+
404+
with patch("openai.OpenAI") as mock_openai_class:
405+
mock_openai_class.return_value = MagicMock()
406+
with patch("instructor.from_openai") as mock_from_openai:
407+
mock_from_openai.return_value = MagicMock()
408+
with patch.dict(
409+
os.environ,
410+
{
411+
"DATABRICKS_TOKEN": "secret-token",
412+
},
413+
clear=True,
414+
):
415+
with pytest.raises(ConfigurationError):
416+
from_provider("databricks/dbrx-instruct")
417+
273418
def test_genai_mode_parameter_passed_to_provider():
274419
"""Test that mode parameter is correctly passed to provider functions."""
275420
from unittest.mock import patch, MagicMock

0 commit comments

Comments
 (0)