Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion integrations/langchain/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ dependencies = [
"unitycatalog-langchain[databricks]>=0.2.0",
"databricks-connect>=16.1.1,<16.4",
"openai>=1.97.1",
"databricks-sdk>=0.63.0",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: update this version constraint after the SDK release.

]

[project.optional-dependencies]
dev = [
"pytest",
"typing_extensions",
"databricks-sdk>=0.34.0",
"databricks-sdk>=0.63.0",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we add it to the main list of dependencies, I think we can delete it from the dev requirements.

"ruff==0.6.4",
]

Expand Down
18 changes: 17 additions & 1 deletion integrations/langchain/src/databricks_langchain/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class ChatDatabricks(BaseChatModel):
model="databricks-claude-3-7-sonnet",
temperature=0,
max_tokens=500,
timeout=30.0, # Timeout in seconds
max_retries=3, # Maximum number of retries
)

# Using a WorkspaceClient instance for custom authentication
Expand Down Expand Up @@ -248,6 +250,10 @@ class GetPopulation(BaseModel):
"""Any extra parameters to pass to the endpoint."""
use_responses_api: bool = False
"""Whether to use the Responses API to format inputs and outputs."""
timeout: Optional[float] = None
"""Timeout in seconds for the HTTP request. If None, uses the default timeout."""
max_retries: Optional[int] = None
"""Maximum number of retries for failed requests. If None, uses the default retry count."""
client: Optional[object] = Field(default=None, exclude=True) #: :meta private:

@property
Expand Down Expand Up @@ -288,7 +294,17 @@ def __init__(self, **kwargs: Any):
)

# Always use OpenAI client (supports both chat completions and responses API)
self.client = get_openai_client(workspace_client=self.workspace_client)
# Prepare kwargs for the SDK call
openai_kwargs = {}
if self.timeout is not None:
openai_kwargs["timeout"] = self.timeout
if self.max_retries is not None:
openai_kwargs["max_retries"] = self.max_retries

self.client = get_openai_client(
workspace_client=self.workspace_client,
**openai_kwargs
)

self.use_responses_api = kwargs.get("use_responses_api", False)
self.extra_params = self.extra_params or {}
Expand Down
14 changes: 8 additions & 6 deletions integrations/langchain/src/databricks_langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,25 @@ def get_deployment_client(target_uri: str) -> Any:
) from e


def get_openai_client(workspace_client: Any = None) -> Any:
def get_openai_client(workspace_client: Any = None, **kwargs) -> Any:
"""Get an OpenAI client configured for Databricks.

Args:
workspace_client: Optional WorkspaceClient instance to use for authentication.
If not provided, creates a default WorkspaceClient.
**kwargs: Additional keyword arguments to pass to get_open_ai_client(),
such as timeout and max_retries.
"""
try:
from databricks.sdk import WorkspaceClient

# If workspace_client is provided, use it directly
if workspace_client is not None:
return workspace_client.serving_endpoints.get_open_ai_client()

# Otherwise, create default workspace client
workspace_client = WorkspaceClient()
return workspace_client.serving_endpoints.get_open_ai_client()
return workspace_client.serving_endpoints.get_open_ai_client(**kwargs)
else:
# Otherwise, create default workspace client
workspace_client = WorkspaceClient()
return workspace_client.serving_endpoints.get_open_ai_client(**kwargs)

except ImportError as e:
raise ImportError(
Expand Down
44 changes: 44 additions & 0 deletions integrations/langchain/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,3 +715,47 @@ def test_chat_databricks_utf8_encoding(model):
if hasattr(chunk, "content") and chunk.content:
full_content += chunk.content
assert "blåbær" in full_content.lower()


def test_chat_databricks_with_timeout_and_retries():
    """Test that ChatDatabricks accepts timeout/max_retries and forwards them.

    Two construction paths are exercised:

    1. No ``workspace_client`` is given, so a (patched) default
       ``WorkspaceClient`` is created and its ``get_open_ai_client`` must
       receive ``timeout`` and ``max_retries``.
    2. An explicit ``workspace_client`` is given and ``get_openai_client`` is
       patched, so the keyword arguments passed to it can be checked directly.
    """
    from unittest.mock import Mock, patch

    # Mock the OpenAI client returned by the workspace client's factory method.
    mock_openai_client = Mock()
    mock_workspace_client = Mock()
    mock_workspace_client.serving_endpoints.get_open_ai_client.return_value = mock_openai_client

    with patch("databricks.sdk.WorkspaceClient", return_value=mock_workspace_client):
        # Create ChatDatabricks with timeout and max_retries.
        chat = ChatDatabricks(
            model="databricks-meta-llama-3-3-70b-instruct",
            timeout=45.0,
            max_retries=3,
        )

        # Verify the parameters are stored on the model.
        assert chat.timeout == 45.0
        assert chat.max_retries == 3

        # Verify the client came from the mocked factory ...
        assert chat.client == mock_openai_client
        # ... and that the factory actually received the tuning parameters.
        # Without this check the test would pass even if forwarding were dropped.
        mock_workspace_client.serving_endpoints.get_open_ai_client.assert_called_once_with(
            timeout=45.0,
            max_retries=3,
        )

    # Test with an explicit workspace_client parameter.
    with patch(
        "databricks_langchain.chat_models.get_openai_client",
        return_value=mock_openai_client,
    ) as mock_get_client:
        chat_with_ws = ChatDatabricks(
            model="databricks-meta-llama-3-3-70b-instruct",
            workspace_client=mock_workspace_client,
            timeout=30.0,
            max_retries=2,
        )

        # Verify get_openai_client was called with all parameters.
        mock_get_client.assert_called_once_with(
            workspace_client=mock_workspace_client,
            timeout=30.0,
            max_retries=2,
        )

        assert chat_with_ws.timeout == 30.0
        assert chat_with_ws.max_retries == 2
77 changes: 74 additions & 3 deletions integrations/langchain/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ def test_workspace_client_parameter() -> None:
llm = ChatDatabricks(model="test-model", workspace_client=mock_workspace_client)

assert llm.client == mock_openai_client
mock_get_client.assert_called_once_with(workspace_client=mock_workspace_client)
# Now expects no additional kwargs when timeout/max_retries are None
mock_get_client.assert_called_once_with(
workspace_client=mock_workspace_client
)


def test_workspace_client_and_target_uri_conflict() -> None:
Expand All @@ -91,6 +94,68 @@ def test_workspace_client_and_target_uri_conflict() -> None:
)


def test_timeout_and_max_retries_parameters() -> None:
    """Check that timeout and max_retries given to ChatDatabricks reach get_openai_client."""
    from unittest.mock import Mock, patch

    fake_client = Mock()
    fake_client.timeout = None
    fake_client.max_retries = None

    # The tuning values under test, reused for construction and for the expectation.
    tuning = {"timeout": 60.0, "max_retries": 5}

    patcher = patch(
        "databricks_langchain.chat_models.get_openai_client",
        return_value=fake_client,
    )
    with patcher as factory_spy:
        llm = ChatDatabricks(model="test-model", **tuning)

        # No workspace client was supplied, so None is forwarded alongside the tuning values.
        factory_spy.assert_called_once_with(workspace_client=None, **tuning)

        # The model keeps both the produced client and the raw settings.
        assert llm.client == fake_client
        assert llm.timeout == 60.0
        assert llm.max_retries == 5


def test_timeout_and_max_retries_with_workspace_client() -> None:
    """Check timeout and max_retries are forwarded together with an explicit workspace_client."""
    from unittest.mock import Mock, patch

    fake_workspace = Mock()
    fake_client = Mock()
    fake_client.timeout = None
    fake_client.max_retries = None

    # Everything get_openai_client is expected to receive, in one place.
    expected_call = {
        "workspace_client": fake_workspace,
        "timeout": 30.0,
        "max_retries": 2,
    }

    patcher = patch(
        "databricks_langchain.chat_models.get_openai_client",
        return_value=fake_client,
    )
    with patcher as factory_spy:
        llm = ChatDatabricks(model="test-model", **expected_call)

        # The factory must see the workspace client and both tuning values.
        factory_spy.assert_called_once_with(**expected_call)

        assert llm.client == fake_client
        assert llm.timeout == 30.0
        assert llm.max_retries == 2


def test_default_workspace_client() -> None:
"""Test that default WorkspaceClient is created when none provided."""
from unittest.mock import Mock, patch
Expand All @@ -107,7 +172,10 @@ def test_default_workspace_client() -> None:
llm = ChatDatabricks(model="test-model")

assert llm.client == mock_openai_client
mock_get_client.assert_called_once_with(workspace_client=None)
# Now expects no additional kwargs when timeout/max_retries are None
mock_get_client.assert_called_once_with(
workspace_client=None
)


def test_target_uri_deprecation_warning() -> None:
Expand Down Expand Up @@ -960,7 +1028,10 @@ def test_chat_databricks_init_sets_client():

llm = ChatDatabricks(model="test-model")

mock_get_client.assert_called_once_with(workspace_client=None)
# Now expects no additional kwargs when timeout/max_retries are None
mock_get_client.assert_called_once_with(
workspace_client=None
)
assert llm.client == mock_client


Expand Down
70 changes: 70 additions & 0 deletions integrations/langchain/tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Test utilities module."""

from unittest.mock import Mock, patch

import pytest

from databricks_langchain.utils import get_openai_client


def test_get_openai_client_with_timeout_and_max_retries() -> None:
    """Check that get_openai_client forwards timeout and max_retries to the SDK factory."""
    expected_client = Mock()
    workspace = Mock()
    # Mock attribute access is stable, so this alias is the same child mock the helper calls.
    sdk_factory = workspace.serving_endpoints.get_open_ai_client
    sdk_factory.return_value = expected_client

    # Call the helper with an explicit workspace client plus both tuning values.
    forwarded = {"timeout": 45.0, "max_retries": 3}
    result = get_openai_client(workspace_client=workspace, **forwarded)

    # The SDK factory must have been invoked exactly once with those kwargs.
    sdk_factory.assert_called_once_with(**forwarded)

    # The helper returns whatever the SDK factory produced.
    assert result == expected_client


def test_get_openai_client_with_default_workspace_client() -> None:
    """Check that a default WorkspaceClient is built and receives the kwargs when none is given."""
    expected_client = Mock()
    default_workspace = Mock()
    sdk_factory = default_workspace.serving_endpoints.get_open_ai_client
    sdk_factory.return_value = expected_client

    forwarded = {"timeout": 30.0, "max_retries": 2}

    # Replace the SDK's WorkspaceClient so the helper's default construction yields our mock.
    with patch("databricks.sdk.WorkspaceClient", return_value=default_workspace):
        result = get_openai_client(**forwarded)

        # The default workspace client's factory must have received the kwargs.
        sdk_factory.assert_called_once_with(**forwarded)

        # The helper returns the client produced by that factory.
        assert result == expected_client


def test_get_openai_client_without_timeout_and_retries() -> None:
    """Check that get_openai_client passes no kwargs to the SDK when none are supplied."""
    expected_client = Mock()
    workspace = Mock()
    sdk_factory = workspace.serving_endpoints.get_open_ai_client
    sdk_factory.return_value = expected_client

    result = get_openai_client(workspace_client=workspace)

    # With no extra arguments given, the SDK factory is called with an empty argument list.
    sdk_factory.assert_called_once_with()

    # The helper returns the client produced by that factory.
    assert result == expected_client
Loading