|
| 1 | +"""Tests ChatMLX tool calling.""" |
| 2 | + |
| 3 | +from typing import Dict |
| 4 | + |
| 5 | +import pytest |
| 6 | +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage |
| 7 | +from langchain_core.tools import tool |
| 8 | + |
| 9 | +from langchain_community.chat_models.mlx import ChatMLX |
| 10 | +from langchain_community.llms.mlx_pipeline import MLXPipeline |
| 11 | + |
| 12 | +# Use a Phi-3 model for more reliable tool-calling behavior |
| 13 | +MODEL_ID = "mlx-community/phi-3-mini-128k-instruct" |
| 14 | + |
| 15 | + |
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    # Multiplication is commutative, so operand order does not matter.
    product = b * a
    return product
| 20 | + |
| 21 | + |
@pytest.fixture(scope="module")
def chat() -> ChatMLX:
    """Return ChatMLX bound with the multiply tool, or skip if unavailable.

    Loading the model can fail for many environment-specific reasons
    (missing weights, no MLX support), so any failure skips rather than
    errors. A skip raised from a module-scoped fixture already skips every
    test that depends on it; ``allow_module_level=True`` is only meaningful
    for module-level ``pytest.skip`` calls and has been dropped.
    """
    try:
        llm = MLXPipeline.from_model_id(
            model_id=MODEL_ID, pipeline_kwargs={"max_new_tokens": 150}
        )
    except Exception:
        pytest.skip("Required MLX model isn't available.")
    chat_model = ChatMLX(llm=llm)
    # bind_tools returns a Runnable wrapper rather than a ChatMLX instance,
    # hence the return-value ignore on the annotation above.
    return chat_model.bind_tools(tools=[multiply], tool_choice=True)  # type: ignore[return-value]
| 33 | + |
| 34 | + |
def _call_tool(tool_call: Dict) -> ToolMessage:
    """Run the multiply tool for *tool_call* and wrap the result.

    Returns a ToolMessage whose content is the stringified product and whose
    tool_call_id echoes the id from the model's tool call (empty if absent).
    """
    args = tool_call["args"]
    call_id = tool_call.get("id", "")
    product = multiply.invoke(args)
    return ToolMessage(content=str(product), tool_call_id=call_id)
| 38 | + |
| 39 | + |
def test_mlx_tool_calls_soft(chat: ChatMLX) -> None:
    """End-to-end round trip: model emits a tool call, we execute it, and
    the model folds the tool result into its final answer."""
    messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")]
    ai_msg = chat.invoke(messages)
    # Assert before indexing so a missing tool call fails with a clear
    # assertion error instead of a raw IndexError (matches the hard test).
    assert ai_msg.tool_calls
    tool_msg = _call_tool(ai_msg.tool_calls[0])
    final = chat.invoke(messages + [ai_msg, tool_msg])
    assert "6" in final.content
| 46 | + |
| 47 | + |
| 48 | +def test_mlx_tool_calls_hard(chat: ChatMLX) -> None: |
| 49 | + messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")] |
| 50 | + ai_msg = chat.invoke(messages) |
| 51 | + assert isinstance(ai_msg, AIMessage) |
| 52 | + assert ai_msg.tool_calls |
| 53 | + tool_call = ai_msg.tool_calls[0] |
| 54 | + assert tool_call["name"] == "multiply" |
| 55 | + assert tool_call["args"] == {"a": 2, "b": 3} |