Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/langchain/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ dependencies = [
"mlflow>=2.20.1",
"pydantic>2.10.0",
"unitycatalog-langchain[databricks]>=0.2.0",
"databricks-sdk>=0.65.0",
"openai>=1.99.9",
]

[project.optional-dependencies]
dev = [
"pytest",
"typing_extensions",
"databricks-sdk>=0.34.0",
"ruff==0.6.4",
]

Expand Down
15 changes: 14 additions & 1 deletion integrations/langchain/src/databricks_langchain/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class ChatDatabricks(BaseChatModel):
model="databricks-claude-3-7-sonnet",
temperature=0,
max_tokens=500,
timeout=30.0, # Timeout in seconds
max_retries=3, # Maximum number of retries
)

# Using a WorkspaceClient instance for custom authentication
Expand Down Expand Up @@ -248,6 +250,10 @@ class GetPopulation(BaseModel):
"""Any extra parameters to pass to the endpoint."""
use_responses_api: bool = False
"""Whether to use the Responses API to format inputs and outputs."""
timeout: Optional[float] = None
"""Timeout in seconds for the HTTP request. If None, uses the default timeout."""
max_retries: Optional[int] = None
"""Maximum number of retries for failed requests. If None, uses the default retry count."""
client: Optional[object] = Field(default=None, exclude=True) #: :meta private:

@property
Expand Down Expand Up @@ -288,7 +294,14 @@ def __init__(self, **kwargs: Any):
)

# Always use OpenAI client (supports both chat completions and responses API)
self.client = get_openai_client(workspace_client=self.workspace_client)
# Prepare kwargs for the SDK call
openai_kwargs = {}
if self.timeout is not None:
openai_kwargs["timeout"] = self.timeout
if self.max_retries is not None:
openai_kwargs["max_retries"] = self.max_retries

self.client = get_openai_client(workspace_client=self.workspace_client, **openai_kwargs)

self.use_responses_api = kwargs.get("use_responses_api", False)
self.extra_params = self.extra_params or {}
Expand Down
14 changes: 8 additions & 6 deletions integrations/langchain/src/databricks_langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,25 @@ def get_deployment_client(target_uri: str) -> Any:
) from e


def get_openai_client(workspace_client: Any = None) -> Any:
def get_openai_client(workspace_client: Any = None, **kwargs) -> Any:
"""Get an OpenAI client configured for Databricks.

Args:
workspace_client: Optional WorkspaceClient instance to use for authentication.
If not provided, creates a default WorkspaceClient.
**kwargs: Additional keyword arguments to pass to get_open_ai_client(),
such as timeout and max_retries.
"""
try:
from databricks.sdk import WorkspaceClient

# If workspace_client is provided, use it directly
if workspace_client is not None:
return workspace_client.serving_endpoints.get_open_ai_client()

# Otherwise, create default workspace client
workspace_client = WorkspaceClient()
return workspace_client.serving_endpoints.get_open_ai_client()
return workspace_client.serving_endpoints.get_open_ai_client(**kwargs)
else:
# Otherwise, create default workspace client
workspace_client = WorkspaceClient()
return workspace_client.serving_endpoints.get_open_ai_client(**kwargs)

except ImportError as e:
raise ImportError(
Expand Down
42 changes: 42 additions & 0 deletions integrations/langchain/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,3 +715,45 @@ def test_chat_databricks_utf8_encoding(model):
if hasattr(chunk, "content") and chunk.content:
full_content += chunk.content
assert "blåbær" in full_content.lower()


def test_chat_databricks_with_timeout_and_retries():
    """Check that ``timeout`` and ``max_retries`` reach the OpenAI client factory.

    Two construction paths are exercised:

    1. No ``workspace_client`` is supplied, so the (patched) default
       ``WorkspaceClient`` is used and its
       ``serving_endpoints.get_open_ai_client`` must receive both settings.
    2. An explicit ``workspace_client`` is supplied, and ``get_openai_client``
       must be called with it plus both settings.
    """
    from unittest.mock import Mock, patch

    model_name = "databricks-meta-llama-3-3-70b-instruct"

    # Fake workspace client whose OpenAI-client factory hands back a sentinel mock.
    mock_openai_client = Mock()
    mock_workspace_client = Mock()
    mock_workspace_client.serving_endpoints.get_open_ai_client.return_value = mock_openai_client

    # Path 1: default WorkspaceClient.
    with patch("databricks.sdk.WorkspaceClient", return_value=mock_workspace_client):
        chat = ChatDatabricks(model=model_name, timeout=45.0, max_retries=3)

        # The settings are stored on the model.
        assert chat.timeout == 45.0
        assert chat.max_retries == 3

        # The model holds the client produced by the SDK factory.
        assert chat.client == mock_openai_client

        # The factory actually received the settings. Comparing the client alone
        # would still pass if the values were dropped on the way to the SDK.
        mock_workspace_client.serving_endpoints.get_open_ai_client.assert_called_once_with(
            timeout=45.0, max_retries=3
        )

    # Path 2: explicit workspace_client, with the helper itself patched out.
    with patch(
        "databricks_langchain.chat_models.get_openai_client", return_value=mock_openai_client
    ) as mock_get_client:
        chat_with_ws = ChatDatabricks(
            model=model_name,
            workspace_client=mock_workspace_client,
            timeout=30.0,
            max_retries=2,
        )

        # Verify get_openai_client was called with all parameters.
        mock_get_client.assert_called_once_with(
            workspace_client=mock_workspace_client, timeout=30.0, max_retries=2
        )

        assert chat_with_ws.timeout == 30.0
        assert chat_with_ws.max_retries == 2
49 changes: 49 additions & 0 deletions integrations/langchain/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,55 @@ def test_workspace_client_and_target_uri_conflict() -> None:
)


def test_timeout_and_max_retries_parameters() -> None:
    """Ensure ChatDatabricks forwards timeout and max_retries to get_openai_client."""
    from unittest.mock import Mock, patch

    # Stand-in for the OpenAI client returned by the patched helper.
    fake_client = Mock(timeout=None, max_retries=None)

    helper_patch = patch(
        "databricks_langchain.chat_models.get_openai_client", return_value=fake_client
    )
    with helper_patch as helper:
        llm = ChatDatabricks(model="test-model", timeout=60.0, max_retries=5)

    # The helper must have been invoked exactly once, with both settings and
    # workspace_client=None (none was passed to the constructor).
    helper.assert_called_once_with(workspace_client=None, timeout=60.0, max_retries=5)

    # The model keeps the returned client and remembers both settings.
    assert llm.client == fake_client
    assert (llm.timeout, llm.max_retries) == (60.0, 5)


def test_timeout_and_max_retries_with_workspace_client() -> None:
    """Ensure timeout and max_retries are forwarded alongside an explicit workspace_client."""
    from unittest.mock import Mock, patch

    fake_workspace = Mock()
    # Stand-in for the OpenAI client returned by the patched helper.
    fake_client = Mock(timeout=None, max_retries=None)

    init_kwargs = {
        "model": "test-model",
        "workspace_client": fake_workspace,
        "timeout": 30.0,
        "max_retries": 2,
    }

    helper_patch = patch(
        "databricks_langchain.chat_models.get_openai_client", return_value=fake_client
    )
    with helper_patch as helper:
        llm = ChatDatabricks(**init_kwargs)

    # The helper receives the caller's workspace client plus both settings.
    helper.assert_called_once_with(workspace_client=fake_workspace, timeout=30.0, max_retries=2)

    assert llm.client == fake_client
    assert (llm.timeout, llm.max_retries) == (30.0, 2)


def test_default_workspace_client() -> None:
"""Test that default WorkspaceClient is created when none provided."""
from unittest.mock import Mock, patch
Expand Down
62 changes: 62 additions & 0 deletions integrations/langchain/tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Test utilities module."""

from unittest.mock import Mock, patch

from databricks_langchain.utils import get_openai_client


def test_get_openai_client_with_timeout_and_max_retries() -> None:
    """Check that get_openai_client hands timeout and max_retries to the SDK as kwargs."""
    expected_client = Mock()
    workspace = Mock()
    sdk_factory = workspace.serving_endpoints.get_open_ai_client
    sdk_factory.return_value = expected_client

    # Call the helper with an explicit workspace client and both settings.
    result = get_openai_client(workspace_client=workspace, timeout=45.0, max_retries=3)

    # The SDK factory must see exactly the two keyword arguments.
    sdk_factory.assert_called_once_with(timeout=45.0, max_retries=3)

    # Whatever the SDK factory returned is passed back to the caller.
    assert result == expected_client


def test_get_openai_client_with_default_workspace_client() -> None:
    """Test get_openai_client creates a default WorkspaceClient when none is provided.

    ``databricks.sdk.WorkspaceClient`` is patched so that no real workspace is
    contacted. The test checks three things: the constructor is called once with
    no arguments, the extra kwargs are forwarded to
    ``serving_endpoints.get_open_ai_client``, and that factory's return value is
    handed back to the caller.
    """
    mock_openai_client = Mock()

    mock_workspace_client = Mock()
    mock_workspace_client.serving_endpoints.get_open_ai_client.return_value = mock_openai_client

    with patch(
        "databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
    ) as mock_workspace_cls:
        client = get_openai_client(timeout=30.0, max_retries=2)

    # The default WorkspaceClient must actually have been constructed. Without
    # this check the test would not notice if the constructor were never called.
    mock_workspace_cls.assert_called_once_with()

    # The kwargs were passed through to the SDK factory untouched.
    mock_workspace_client.serving_endpoints.get_open_ai_client.assert_called_once_with(
        timeout=30.0, max_retries=2
    )

    # Verify the client is returned.
    assert client == mock_openai_client


def test_get_openai_client_without_timeout_and_retries() -> None:
    """Check that get_openai_client forwards no kwargs when none are supplied."""
    expected_client = Mock()
    workspace = Mock()
    sdk_factory = workspace.serving_endpoints.get_open_ai_client
    sdk_factory.return_value = expected_client

    result = get_openai_client(workspace_client=workspace)

    # With no extra settings, the SDK factory must be called with no arguments.
    sdk_factory.assert_called_once_with()

    # Whatever the SDK factory returned is passed back to the caller.
    assert result == expected_client