diff --git a/integrations/langchain/src/databricks_langchain/chat_models.py b/integrations/langchain/src/databricks_langchain/chat_models.py
index df2760eb..3a1773b4 100644
--- a/integrations/langchain/src/databricks_langchain/chat_models.py
+++ b/integrations/langchain/src/databricks_langchain/chat_models.py
@@ -252,7 +252,10 @@ def endpoint(self, value: str) -> None:
 
     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
-        self.client = get_deployment_client(self.target_uri)
+        if "client" in kwargs:
+            self.client = kwargs["client"]
+        else:
+            self.client = get_deployment_client(self.target_uri)
         self.extra_params = self.extra_params or {}
 
     @property
diff --git a/integrations/langchain/tests/unit_tests/test_chat_models.py b/integrations/langchain/tests/unit_tests/test_chat_models.py
index 451a845a..66e2eeec 100644
--- a/integrations/langchain/tests/unit_tests/test_chat_models.py
+++ b/integrations/langchain/tests/unit_tests/test_chat_models.py
@@ -1,6 +1,7 @@
 """Test chat model integration."""
 
 import json
+from unittest.mock import MagicMock
 
 import mlflow  # type: ignore  # noqa: F401
 import pytest
@@ -365,3 +366,41 @@ def test_convert_response_to_chat_result_llm_output(llm: ChatDatabricks) -> None
     assert "content" not in result.llm_output
     assert "role" not in result.llm_output
     assert "type" not in result.llm_output
+
+
+def test_chat_model_with_custom_client() -> None:
+    """Test that ChatDatabricks can be instantiated with a custom client."""
+    mock_custom_client = MagicMock(spec=mlflow.deployments.BaseDeploymentClient)
+    # Configure the mock's predict method to return a valid response structure
+    mock_predict_response = {
+        "choices": [
+            {
+                "message": {"role": "assistant", "content": "Test response"},
+                "finish_reason": "stop",
+            }
+        ],
+        "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
+        "model": "test-model",
+    }
+    mock_custom_client.predict.return_value = mock_predict_response
+
+    chat_model = ChatDatabricks(
+        model="test-model",
+        target_uri="databricks",  # This shouldn't be used if client is provided
+        client=mock_custom_client,
+    )
+
+    assert chat_model.client is mock_custom_client
+
+    sample_messages = [HumanMessage(content="Hello")]
+    chat_model._generate(messages=sample_messages)
+
+    mock_custom_client.predict.assert_called_once()
+    # Check that the 'endpoint' argument to predict matches chat_model.model
+    _, call_kwargs = mock_custom_client.predict.call_args
+    assert call_kwargs["endpoint"] == "test-model"
+    # Check structure of 'inputs' argument
+    assert "messages" in call_kwargs["inputs"]
+    assert len(call_kwargs["inputs"]["messages"]) == 1
+    assert call_kwargs["inputs"]["messages"][0]["content"] == "Hello"
+    assert call_kwargs["inputs"]["messages"][0]["role"] == "user"