diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml
index b03e1b2f..ffdf16ba 100644
--- a/integrations/langchain/pyproject.toml
+++ b/integrations/langchain/pyproject.toml
@@ -15,6 +15,7 @@ dependencies = [
     "mlflow>=2.20.1",
     "pydantic>2.10.0",
     "unitycatalog-langchain[databricks]>=0.2.0",
+    "databricks-sdk>=0.65.0",
     "openai>=1.99.9",
 ]
 
@@ -22,7 +23,6 @@ dependencies = [
 dev = [
     "pytest",
     "typing_extensions",
-    "databricks-sdk>=0.34.0",
     "ruff==0.6.4",
 ]
 
diff --git a/integrations/langchain/src/databricks_langchain/chat_models.py b/integrations/langchain/src/databricks_langchain/chat_models.py
index 8a09c776..896e20df 100644
--- a/integrations/langchain/src/databricks_langchain/chat_models.py
+++ b/integrations/langchain/src/databricks_langchain/chat_models.py
@@ -72,6 +72,8 @@ class ChatDatabricks(BaseChatModel):
                 model="databricks-claude-3-7-sonnet",
                 temperature=0,
                 max_tokens=500,
+                timeout=30.0,  # Timeout in seconds
+                max_retries=3,  # Maximum number of retries
             )
 
         # Using a WorkspaceClient instance for custom authentication
@@ -248,6 +250,10 @@ class GetPopulation(BaseModel):
     """Any extra parameters to pass to the endpoint."""
     use_responses_api: bool = False
     """Whether to use the Responses API to format inputs and outputs."""
+    timeout: Optional[float] = None
+    """Timeout in seconds for the HTTP request. If None, uses the default timeout."""
+    max_retries: Optional[int] = None
+    """Maximum number of retries for failed requests. If None, uses the default retry count."""
     client: Optional[object] = Field(default=None, exclude=True)  #: :meta private:
 
     @property
@@ -288,7 +294,14 @@ def __init__(self, **kwargs: Any):
             )
 
         # Always use OpenAI client (supports both chat completions and responses API)
-        self.client = get_openai_client(workspace_client=self.workspace_client)
+        # Prepare kwargs for the SDK call
+        openai_kwargs = {}
+        if self.timeout is not None:
+            openai_kwargs["timeout"] = self.timeout
+        if self.max_retries is not None:
+            openai_kwargs["max_retries"] = self.max_retries
+
+        self.client = get_openai_client(workspace_client=self.workspace_client, **openai_kwargs)
 
         self.use_responses_api = kwargs.get("use_responses_api", False)
         self.extra_params = self.extra_params or {}
diff --git a/integrations/langchain/src/databricks_langchain/utils.py b/integrations/langchain/src/databricks_langchain/utils.py
index 1494f101..fe35891e 100644
--- a/integrations/langchain/src/databricks_langchain/utils.py
+++ b/integrations/langchain/src/databricks_langchain/utils.py
@@ -20,23 +20,25 @@ def get_deployment_client(target_uri: str) -> Any:
         ) from e
 
 
-def get_openai_client(workspace_client: Any = None) -> Any:
+def get_openai_client(workspace_client: Any = None, **kwargs) -> Any:
     """Get an OpenAI client configured for Databricks.
 
     Args:
         workspace_client: Optional WorkspaceClient instance to use for authentication.
             If not provided, creates a default WorkspaceClient.
+        **kwargs: Additional keyword arguments to pass to get_open_ai_client(),
+            such as timeout and max_retries.
     """
     try:
         from databricks.sdk import WorkspaceClient
 
         # If workspace_client is provided, use it directly
         if workspace_client is not None:
-            return workspace_client.serving_endpoints.get_open_ai_client()
-
-        # Otherwise, create default workspace client
-        workspace_client = WorkspaceClient()
-        return workspace_client.serving_endpoints.get_open_ai_client()
+            return workspace_client.serving_endpoints.get_open_ai_client(**kwargs)
+        else:
+            # Otherwise, create default workspace client
+            workspace_client = WorkspaceClient()
+            return workspace_client.serving_endpoints.get_open_ai_client(**kwargs)
 
     except ImportError as e:
         raise ImportError(
diff --git a/integrations/langchain/tests/integration_tests/test_chat_models.py b/integrations/langchain/tests/integration_tests/test_chat_models.py
index dbda1dae..42233d37 100644
--- a/integrations/langchain/tests/integration_tests/test_chat_models.py
+++ b/integrations/langchain/tests/integration_tests/test_chat_models.py
@@ -715,3 +715,45 @@ def test_chat_databricks_utf8_encoding(model):
         if hasattr(chunk, "content") and chunk.content:
             full_content += chunk.content
     assert "blåbær" in full_content.lower()
+
+
+def test_chat_databricks_with_timeout_and_retries():
+    """Test that ChatDatabricks can be initialized with timeout and max_retries parameters."""
+    from unittest.mock import Mock, patch
+
+    # Mock the OpenAI client
+    mock_openai_client = Mock()
+    mock_workspace_client = Mock()
+    mock_workspace_client.serving_endpoints.get_open_ai_client.return_value = mock_openai_client
+
+    with patch("databricks.sdk.WorkspaceClient", return_value=mock_workspace_client):
+        # Create ChatDatabricks with timeout and max_retries
+        chat = ChatDatabricks(
+            model="databricks-meta-llama-3-3-70b-instruct", timeout=45.0, max_retries=3
+        )
+
+        # Verify the parameters are set correctly
+        assert chat.timeout == 45.0
+        assert chat.max_retries == 3
+
+        # Verify the client was configured with these parameters
+        assert chat.client == mock_openai_client
+
+    # Test with workspace_client parameter
+    with patch(
+        "databricks_langchain.chat_models.get_openai_client", return_value=mock_openai_client
+    ) as mock_get_client:
+        chat_with_ws = ChatDatabricks(
+            model="databricks-meta-llama-3-3-70b-instruct",
+            workspace_client=mock_workspace_client,
+            timeout=30.0,
+            max_retries=2,
+        )
+
+        # Verify get_openai_client was called with all parameters
+        mock_get_client.assert_called_once_with(
+            workspace_client=mock_workspace_client, timeout=30.0, max_retries=2
+        )
+
+        assert chat_with_ws.timeout == 30.0
+        assert chat_with_ws.max_retries == 2
diff --git a/integrations/langchain/tests/unit_tests/test_chat_models.py b/integrations/langchain/tests/unit_tests/test_chat_models.py
index 943cb9b8..2ee6c12e 100644
--- a/integrations/langchain/tests/unit_tests/test_chat_models.py
+++ b/integrations/langchain/tests/unit_tests/test_chat_models.py
@@ -91,6 +91,55 @@ def test_workspace_client_and_target_uri_conflict() -> None:
     )
 
 
+def test_timeout_and_max_retries_parameters() -> None:
+    """Test that timeout and max_retries parameters are properly passed to the OpenAI client."""
+    from unittest.mock import Mock, patch
+
+    mock_openai_client = Mock()
+    mock_openai_client.timeout = None
+    mock_openai_client.max_retries = None
+
+    with patch(
+        "databricks_langchain.chat_models.get_openai_client", return_value=mock_openai_client
+    ) as mock_get_client:
+        # Test with timeout and max_retries
+        llm = ChatDatabricks(model="test-model", timeout=60.0, max_retries=5)
+
+        # Verify get_openai_client was called with the correct parameters
+        mock_get_client.assert_called_once_with(workspace_client=None, timeout=60.0, max_retries=5)
+
+        # Test that client is set
+        assert llm.client == mock_openai_client
+        assert llm.timeout == 60.0
+        assert llm.max_retries == 5
+
+
+def test_timeout_and_max_retries_with_workspace_client() -> None:
+    """Test timeout and max_retries parameters work with workspace_client."""
+    from unittest.mock import Mock, patch
+
+    mock_workspace_client = Mock()
+    mock_openai_client = Mock()
+    mock_openai_client.timeout = None
+    mock_openai_client.max_retries = None
+
+    with patch(
+        "databricks_langchain.chat_models.get_openai_client", return_value=mock_openai_client
+    ) as mock_get_client:
+        llm = ChatDatabricks(
+            model="test-model", workspace_client=mock_workspace_client, timeout=30.0, max_retries=2
+        )
+
+        # Verify get_openai_client was called with all parameters
+        mock_get_client.assert_called_once_with(
+            workspace_client=mock_workspace_client, timeout=30.0, max_retries=2
+        )
+
+        assert llm.client == mock_openai_client
+        assert llm.timeout == 30.0
+        assert llm.max_retries == 2
+
+
 def test_default_workspace_client() -> None:
     """Test that default WorkspaceClient is created when none provided."""
     from unittest.mock import Mock, patch
diff --git a/integrations/langchain/tests/unit_tests/test_utils.py b/integrations/langchain/tests/unit_tests/test_utils.py
new file mode 100644
index 00000000..8ee0f58e
--- /dev/null
+++ b/integrations/langchain/tests/unit_tests/test_utils.py
@@ -0,0 +1,62 @@
+"""Test utilities module."""
+
+from unittest.mock import Mock, patch
+
+from databricks_langchain.utils import get_openai_client
+
+
+def test_get_openai_client_with_timeout_and_max_retries() -> None:
+    """Test that get_openai_client properly passes timeout and max_retries as kwargs to the SDK."""
+
+    mock_openai_client = Mock()
+
+    mock_workspace_client = Mock()
+    mock_workspace_client.serving_endpoints.get_open_ai_client.return_value = mock_openai_client
+
+    # Test with workspace_client, timeout, and max_retries
+    client = get_openai_client(workspace_client=mock_workspace_client, timeout=45.0, max_retries=3)
+
+    # Verify the OpenAI client was obtained with the correct kwargs
+    mock_workspace_client.serving_endpoints.get_open_ai_client.assert_called_once_with(
+        timeout=45.0, max_retries=3
+    )
+
+    # Verify the client is returned
+    assert client == mock_openai_client
+
+
+def test_get_openai_client_with_default_workspace_client() -> None:
+    """Test get_openai_client creates default WorkspaceClient when none provided."""
+
+    mock_openai_client = Mock()
+
+    mock_workspace_client = Mock()
+    mock_workspace_client.serving_endpoints.get_open_ai_client.return_value = mock_openai_client
+
+    with patch("databricks.sdk.WorkspaceClient", return_value=mock_workspace_client):
+        client = get_openai_client(timeout=30.0, max_retries=2)
+
+    # Verify default WorkspaceClient was created and kwargs were passed
+    mock_workspace_client.serving_endpoints.get_open_ai_client.assert_called_once_with(
+        timeout=30.0, max_retries=2
+    )
+
+    # Verify the client is returned
+    assert client == mock_openai_client
+
+
+def test_get_openai_client_without_timeout_and_retries() -> None:
+    """Test get_openai_client doesn't pass kwargs when not provided."""
+
+    mock_openai_client = Mock()
+
+    mock_workspace_client = Mock()
+    mock_workspace_client.serving_endpoints.get_open_ai_client.return_value = mock_openai_client
+
+    client = get_openai_client(workspace_client=mock_workspace_client)
+
+    # Verify the OpenAI client was obtained without kwargs
+    mock_workspace_client.serving_endpoints.get_open_ai_client.assert_called_once_with()
+
+    # Verify the client is returned
+    assert client == mock_openai_client