Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions libs/community/langchain_community/chat_models/moonshot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
"""Wrapper around Moonshot chat models."""

from typing import Dict
from typing import Any, Callable, Dict, Sequence, Type, Union

from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
pre_init,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel

from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms.moonshot import MOONSHOT_SERVICE_URL_BASE, MoonshotCommon
Expand Down Expand Up @@ -172,9 +178,11 @@ def validate_environment(cls, values: Dict) -> Dict:

client_params = {
"api_key": values["moonshot_api_key"].get_secret_value(),
"base_url": values["base_url"]
if "base_url" in values
else MOONSHOT_SERVICE_URL_BASE,
"base_url": (
values["base_url"]
if "base_url" in values
else MOONSHOT_SERVICE_URL_BASE
),
}

if not values.get("client"):
Expand All @@ -185,3 +193,22 @@ def validate_environment(cls, values: Dict) -> Dict:
).chat.completions

return values

def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
"""Bind tool-like objects to this chat model.

Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""

formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
13 changes: 13 additions & 0 deletions libs/community/tests/unit_tests/chat_models/test_moonshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest
from langchain_core.runnables import Runnable

from langchain_community.chat_models.moonshot import MoonshotChat

mock_tool_list = [lambda: f"tool-id-{i}" for i in range(3)]


@pytest.mark.requires("openai")
def test_moonshot_bind_tools() -> None:
llm = MoonshotChat(name="moonshot")
ret: Runnable = llm.bind_tools(mock_tool_list)
assert len(ret.kwargs["tools"]) == 3