diff --git a/libs/community/langchain_community/chat_models/moonshot.py b/libs/community/langchain_community/chat_models/moonshot.py index 6d31426fd..48ffd4949 100644 --- a/libs/community/langchain_community/chat_models/moonshot.py +++ b/libs/community/langchain_community/chat_models/moonshot.py @@ -1,12 +1,18 @@ """Wrapper around Moonshot chat models.""" -from typing import Dict +from typing import Any, Callable, Dict, Sequence, Type, Union +from langchain_core.language_models import LanguageModelInput +from langchain_core.messages import AIMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool from langchain_core.utils import ( convert_to_secret_str, get_from_dict_or_env, pre_init, ) +from langchain_core.utils.function_calling import convert_to_openai_tool +from pydantic import BaseModel from langchain_community.chat_models import ChatOpenAI from langchain_community.llms.moonshot import MOONSHOT_SERVICE_URL_BASE, MoonshotCommon @@ -172,9 +178,11 @@ def validate_environment(cls, values: Dict) -> Dict: client_params = { "api_key": values["moonshot_api_key"].get_secret_value(), - "base_url": values["base_url"] - if "base_url" in values - else MOONSHOT_SERVICE_URL_BASE, + "base_url": ( + values["base_url"] + if "base_url" in values + else MOONSHOT_SERVICE_URL_BASE + ), } if not values.get("client"): @@ -185,3 +193,22 @@ def validate_environment(cls, values: Dict) -> Dict: ).chat.completions return values + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, AIMessage]: + """Bind tool-like objects to this chat model. + + Args: + tools: A list of tool definitions to bind to this chat model. + Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic + models, callables, and BaseTools will be automatically converted to + their schema dictionary representation. + **kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + return super().bind(tools=formatted_tools, **kwargs) diff --git a/libs/community/tests/unit_tests/chat_models/test_moonshot.py b/libs/community/tests/unit_tests/chat_models/test_moonshot.py new file mode 100644 index 000000000..531d8c179 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_moonshot.py @@ -0,0 +1,13 @@ +import pytest +from langchain_core.runnables import Runnable + +from langchain_community.chat_models.moonshot import MoonshotChat + +mock_tool_list = [lambda: f"tool-id-{i}" for i in range(3)] + + +@pytest.mark.requires("openai") +def test_moonshot_bind_tools() -> None: + llm = MoonshotChat(name="moonshot") + ret: Runnable = llm.bind_tools(mock_tool_list) + assert len(ret.kwargs["tools"]) == 3