Skip to content

Commit 8fda0d4

Browse files
diego-coderRN
authored andcommitted
test: use phi-3 model for ChatMLX tools
1 parent c210e07 commit 8fda0d4

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Tests ChatMLX tool calling."""
2+
3+
from typing import Dict
4+
5+
import pytest
6+
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
7+
from langchain_core.tools import tool
8+
9+
from langchain_community.chat_models.mlx import ChatMLX
10+
from langchain_community.llms.mlx_pipeline import MLXPipeline
11+
12+
# Use a Phi-3 model for more reliable tool-calling behavior
13+
MODEL_ID = "mlx-community/phi-3-mini-128k-instruct"
14+
15+
16+
@tool
17+
def multiply(a: int, b: int) -> int:
18+
"""Multiply two integers."""
19+
return a * b
20+
21+
22+
@pytest.fixture(scope="module")
23+
def chat() -> ChatMLX:
24+
"""Return ChatMLX bound with the multiply tool or skip if unavailable."""
25+
try:
26+
llm = MLXPipeline.from_model_id(
27+
model_id=MODEL_ID, pipeline_kwargs={"max_new_tokens": 150}
28+
)
29+
except Exception:
30+
pytest.skip("Required MLX model isn't available.", allow_module_level=True)
31+
chat_model = ChatMLX(llm=llm)
32+
return chat_model.bind_tools(tools=[multiply], tool_choice=True) # type: ignore[return-value]
33+
34+
35+
def _call_tool(tool_call: Dict) -> ToolMessage:
36+
result = multiply.invoke(tool_call["args"])
37+
return ToolMessage(content=str(result), tool_call_id=tool_call.get("id", ""))
38+
39+
40+
def test_mlx_tool_calls_soft(chat: ChatMLX) -> None:
41+
messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")]
42+
ai_msg = chat.invoke(messages)
43+
tool_msg = _call_tool(ai_msg.tool_calls[0])
44+
final = chat.invoke(messages + [ai_msg, tool_msg])
45+
assert "6" in final.content
46+
47+
48+
def test_mlx_tool_calls_hard(chat: ChatMLX) -> None:
49+
messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")]
50+
ai_msg = chat.invoke(messages)
51+
assert isinstance(ai_msg, AIMessage)
52+
assert ai_msg.tool_calls
53+
tool_call = ai_msg.tool_calls[0]
54+
assert tool_call["name"] == "multiply"
55+
assert tool_call["args"] == {"a": 2, "b": 3}

0 commit comments

Comments
 (0)