Skip to content

Commit c9b48dc

Browse files
authored
Add timeout and retry params to ChatDatabricks (#165)
Signed-off-by: Sid Murching <[email protected]>
1 parent d4f5ec7 commit c9b48dc

File tree

6 files changed

+176
-8
lines changed

6 files changed

+176
-8
lines changed

integrations/langchain/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ dependencies = [
1515
"mlflow>=2.20.1",
1616
"pydantic>2.10.0",
1717
"unitycatalog-langchain[databricks]>=0.2.0",
18+
"databricks-sdk>=0.65.0",
1819
"openai>=1.99.9",
1920
]
2021

2122
[project.optional-dependencies]
2223
dev = [
2324
"pytest",
2425
"typing_extensions",
25-
"databricks-sdk>=0.34.0",
2626
"ruff==0.6.4",
2727
]
2828

integrations/langchain/src/databricks_langchain/chat_models.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ class ChatDatabricks(BaseChatModel):
7272
model="databricks-claude-3-7-sonnet",
7373
temperature=0,
7474
max_tokens=500,
75+
timeout=30.0, # Timeout in seconds
76+
max_retries=3, # Maximum number of retries
7577
)
7678
7779
# Using a WorkspaceClient instance for custom authentication
@@ -248,6 +250,10 @@ class GetPopulation(BaseModel):
248250
"""Any extra parameters to pass to the endpoint."""
249251
use_responses_api: bool = False
250252
"""Whether to use the Responses API to format inputs and outputs."""
253+
timeout: Optional[float] = None
254+
"""Timeout in seconds for the HTTP request. If None, uses the default timeout."""
255+
max_retries: Optional[int] = None
256+
"""Maximum number of retries for failed requests. If None, uses the default retry count."""
251257
client: Optional[object] = Field(default=None, exclude=True) #: :meta private:
252258

253259
@property
@@ -288,7 +294,14 @@ def __init__(self, **kwargs: Any):
288294
)
289295

290296
# Always use OpenAI client (supports both chat completions and responses API)
291-
self.client = get_openai_client(workspace_client=self.workspace_client)
297+
# Prepare kwargs for the SDK call
298+
openai_kwargs = {}
299+
if self.timeout is not None:
300+
openai_kwargs["timeout"] = self.timeout
301+
if self.max_retries is not None:
302+
openai_kwargs["max_retries"] = self.max_retries
303+
304+
self.client = get_openai_client(workspace_client=self.workspace_client, **openai_kwargs)
292305

293306
self.use_responses_api = kwargs.get("use_responses_api", False)
294307
self.extra_params = self.extra_params or {}

integrations/langchain/src/databricks_langchain/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,25 @@ def get_deployment_client(target_uri: str) -> Any:
2020
) from e
2121

2222

23-
def get_openai_client(workspace_client: Any = None) -> Any:
23+
def get_openai_client(workspace_client: Any = None, **kwargs) -> Any:
2424
"""Get an OpenAI client configured for Databricks.
2525
2626
Args:
2727
workspace_client: Optional WorkspaceClient instance to use for authentication.
2828
If not provided, creates a default WorkspaceClient.
29+
**kwargs: Additional keyword arguments to pass to get_open_ai_client(),
30+
such as timeout and max_retries.
2931
"""
3032
try:
3133
from databricks.sdk import WorkspaceClient
3234

3335
# If workspace_client is provided, use it directly
3436
if workspace_client is not None:
35-
return workspace_client.serving_endpoints.get_open_ai_client()
36-
37-
# Otherwise, create default workspace client
38-
workspace_client = WorkspaceClient()
39-
return workspace_client.serving_endpoints.get_open_ai_client()
37+
return workspace_client.serving_endpoints.get_open_ai_client(**kwargs)
38+
else:
39+
# Otherwise, create default workspace client
40+
workspace_client = WorkspaceClient()
41+
return workspace_client.serving_endpoints.get_open_ai_client(**kwargs)
4042

4143
except ImportError as e:
4244
raise ImportError(

integrations/langchain/tests/integration_tests/test_chat_models.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,3 +715,45 @@ def test_chat_databricks_utf8_encoding(model):
715715
if hasattr(chunk, "content") and chunk.content:
716716
full_content += chunk.content
717717
assert "blåbær" in full_content.lower()
718+
719+
720+
def test_chat_databricks_with_timeout_and_retries():
    """Test that ChatDatabricks accepts timeout/max_retries and forwards them.

    Two scenarios are covered:

    1. No ``workspace_client`` is supplied. ``databricks.sdk.WorkspaceClient``
       is patched so the default-client path yields a mock, and the test checks
       that the SDK's ``get_open_ai_client`` receives the timeout/retry kwargs.
    2. An explicit ``workspace_client`` is supplied. ``get_openai_client`` is
       patched in ``chat_models`` and the test checks the arguments it gets.
    """
    from unittest.mock import Mock, patch

    # Mock the OpenAI client and the workspace client that hands it out.
    mock_openai_client = Mock()
    mock_workspace_client = Mock()
    sdk_get_client = mock_workspace_client.serving_endpoints.get_open_ai_client
    sdk_get_client.return_value = mock_openai_client

    with patch("databricks.sdk.WorkspaceClient", return_value=mock_workspace_client):
        # Create ChatDatabricks with timeout and max_retries
        chat = ChatDatabricks(
            model="databricks-meta-llama-3-3-70b-instruct", timeout=45.0, max_retries=3
        )

        # Verify the parameters are stored on the model
        assert chat.timeout == 45.0
        assert chat.max_retries == 3

        # Verify the client came from the SDK *and* was configured with the
        # timeout/retry settings. Checking identity alone would still pass if
        # the kwargs were silently dropped on the way to the SDK.
        assert chat.client == mock_openai_client
        sdk_get_client.assert_called_once_with(timeout=45.0, max_retries=3)

    # Test with workspace_client parameter
    with patch(
        "databricks_langchain.chat_models.get_openai_client", return_value=mock_openai_client
    ) as mock_get_client:
        chat_with_ws = ChatDatabricks(
            model="databricks-meta-llama-3-3-70b-instruct",
            workspace_client=mock_workspace_client,
            timeout=30.0,
            max_retries=2,
        )

        # Verify get_openai_client was called with all parameters
        mock_get_client.assert_called_once_with(
            workspace_client=mock_workspace_client, timeout=30.0, max_retries=2
        )

        assert chat_with_ws.timeout == 30.0
        assert chat_with_ws.max_retries == 2

integrations/langchain/tests/unit_tests/test_chat_models.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,55 @@ def test_workspace_client_and_target_uri_conflict() -> None:
9191
)
9292

9393

94+
def test_timeout_and_max_retries_parameters() -> None:
    """Check that ChatDatabricks hands timeout and max_retries to get_openai_client.

    ``get_openai_client`` is replaced where ``chat_models`` looks it up, so the
    test observes exactly which keyword arguments the constructor forwards.
    """
    from unittest.mock import Mock, patch

    # Stand-in OpenAI client whose timeout/max_retries attributes start as None.
    fake_client = Mock(timeout=None, max_retries=None)

    factory_patch = patch(
        "databricks_langchain.chat_models.get_openai_client", return_value=fake_client
    )
    with factory_patch as factory_spy:
        llm = ChatDatabricks(model="test-model", timeout=60.0, max_retries=5)

        # No workspace client was given, so None is forwarded alongside the
        # timeout/retry settings.
        factory_spy.assert_called_once_with(workspace_client=None, timeout=60.0, max_retries=5)

        # The model keeps both the returned client and the configured values.
        assert llm.client == fake_client
        assert llm.timeout == 60.0
        assert llm.max_retries == 5
115+
116+
117+
def test_timeout_and_max_retries_with_workspace_client() -> None:
    """Check timeout/max_retries forwarding when a workspace_client is supplied."""
    from unittest.mock import Mock, patch

    ws_client = Mock()
    # Stand-in OpenAI client whose timeout/max_retries attributes start as None.
    fake_client = Mock(timeout=None, max_retries=None)

    factory_patch = patch(
        "databricks_langchain.chat_models.get_openai_client", return_value=fake_client
    )
    with factory_patch as factory_spy:
        llm = ChatDatabricks(
            model="test-model", workspace_client=ws_client, timeout=30.0, max_retries=2
        )

        # The caller's workspace client must be passed through together with
        # the timeout/retry settings.
        factory_spy.assert_called_once_with(
            workspace_client=ws_client, timeout=30.0, max_retries=2
        )

        assert llm.client == fake_client
        assert llm.timeout == 30.0
        assert llm.max_retries == 2
141+
142+
94143
def test_default_workspace_client() -> None:
95144
"""Test that default WorkspaceClient is created when none provided."""
96145
from unittest.mock import Mock, patch
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Test utilities module."""
2+
3+
from unittest.mock import Mock, patch
4+
5+
from databricks_langchain.utils import get_openai_client
6+
7+
8+
def test_get_openai_client_with_timeout_and_max_retries() -> None:
    """Check that get_openai_client relays timeout and max_retries to the SDK call."""
    expected_client = Mock()
    ws_client = Mock()
    sdk_factory = ws_client.serving_endpoints.get_open_ai_client
    sdk_factory.return_value = expected_client

    # Supply an explicit workspace client plus both tuning parameters.
    result = get_openai_client(workspace_client=ws_client, timeout=45.0, max_retries=3)

    # The SDK factory must receive exactly the kwargs given above.
    sdk_factory.assert_called_once_with(timeout=45.0, max_retries=3)

    # The SDK's client is handed back to the caller.
    assert result == expected_client
26+
27+
28+
def test_get_openai_client_with_default_workspace_client() -> None:
    """Test get_openai_client creates default WorkspaceClient when none provided.

    ``databricks.sdk.WorkspaceClient`` is patched so that constructing the
    default client yields a mock. The test then checks that the constructor
    was invoked once with no arguments, that the timeout/retry kwargs reached
    the SDK's ``get_open_ai_client``, and that its result is returned.
    """
    mock_openai_client = Mock()

    mock_workspace_client = Mock()
    mock_workspace_client.serving_endpoints.get_open_ai_client.return_value = mock_openai_client

    with patch(
        "databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
    ) as mock_workspace_cls:
        client = get_openai_client(timeout=30.0, max_retries=2)

        # Verify the default WorkspaceClient was actually constructed; without
        # this the test never checks the "creates default client" claim.
        mock_workspace_cls.assert_called_once_with()

        # Verify the kwargs were passed through to the SDK
        mock_workspace_client.serving_endpoints.get_open_ai_client.assert_called_once_with(
            timeout=30.0, max_retries=2
        )

        # Verify the client is returned
        assert client == mock_openai_client
46+
47+
48+
def test_get_openai_client_without_timeout_and_retries() -> None:
    """Check that get_openai_client adds no kwargs when the caller gives none."""
    expected_client = Mock()
    ws_client = Mock()
    sdk_factory = ws_client.serving_endpoints.get_open_ai_client
    sdk_factory.return_value = expected_client

    result = get_openai_client(workspace_client=ws_client)

    # With no extra arguments supplied, the SDK factory is called bare.
    sdk_factory.assert_called_once_with()

    # The SDK's client is handed back to the caller.
    assert result == expected_client

0 commit comments

Comments
 (0)