
Commit b2d3001

diego-coderRN authored and committed
test(mlx): add async tool test
1 parent 1e2d589 · commit b2d3001

File tree

2 files changed (+86 -3 lines)
  • libs/community


libs/community/langchain_community/chat_models/mlx.py

Lines changed: 15 additions & 3 deletions
@@ -99,7 +99,8 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        llm_input = self._to_chat_prompt(messages)
+        tools = kwargs.pop("tools", None)
+        llm_input = self._to_chat_prompt(messages, tools=tools)
         llm_result = self.llm._generate(
             prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
         )
@@ -112,7 +113,8 @@ async def _agenerate(
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        llm_input = self._to_chat_prompt(messages)
+        tools = kwargs.pop("tools", None)
+        llm_input = self._to_chat_prompt(messages, tools=tools)
         llm_result = await self.llm._agenerate(
             prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
         )
@@ -123,8 +125,17 @@ def _to_chat_prompt(
         messages: List[BaseMessage],
         tokenize: bool = False,
         return_tensors: Optional[str] = None,
+        tools: Sequence[dict] | None = None,
     ) -> str:
-        """Convert a list of messages into a prompt format expected by wrapped LLM."""
+        """Convert messages to the prompt format expected by the wrapped LLM.
+
+        Args:
+            messages: Chat messages to include in the prompt.
+            tokenize: Whether to return token IDs instead of text.
+            return_tensors: Framework for returned tensors when ``tokenize`` is
+                True.
+            tools: Optional tool definitions to include in the prompt.
+        """
         if not messages:
             raise ValueError("At least one HumanMessage must be provided!")
 
@@ -137,6 +148,7 @@ def _to_chat_prompt(
             tokenize=tokenize,
             add_generation_prompt=True,
             return_tensors=return_tensors,
+            tools=tools,
         )
 
     def _to_chatml_format(self, message: BaseMessage) -> dict:
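
For context, a minimal usage sketch (not part of this commit) of how the new tools kwarg flows end to end: _generate pops tools from **kwargs and hands it to _to_chat_prompt, which forwards it to the tokenizer's apply_chat_template. The model id and the get_weather tool below are illustrative assumptions, and the sketch assumes the underlying tokenizer's chat template accepts a tools argument, as recent transformers tokenizers do.

# Minimal sketch (not from this commit): passing tool definitions through
# ChatMLX into the tokenizer's chat template. Model id and tool schema are
# hypothetical examples.
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline

llm = MLXPipeline.from_model_id("mlx-community/quantized-gemma-2b-it")
chat = ChatMLX(llm=llm)

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool for illustration
            "description": "Return the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# _generate pops ``tools`` from **kwargs and _to_chat_prompt forwards it
# to tokenizer.apply_chat_template(..., tools=tools).
message = chat.invoke([HumanMessage(content="Weather in Paris?")], tools=tools)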

libs/community/tests/unit_tests/chat_models/test_mlx.py

Lines changed: 71 additions & 0 deletions
@@ -2,6 +2,42 @@
 
 from importlib import import_module
 
+import pytest
+from langchain_core.messages import HumanMessage
+
+from langchain_community.chat_models.mlx import ChatMLX
+
+
+class _FakeTokenizer:
+    def __init__(self) -> None:
+        self.tools = None
+
+    def apply_chat_template(
+        self,
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+        return_tensors=None,
+        tools=None,
+    ) -> str:
+        self.tools = tools
+        return "prompt"
+
+
+class _FakeLLM:
+    def __init__(self) -> None:
+        self.tokenizer = _FakeTokenizer()
+
+    def _generate(self, prompts, stop=None, run_manager=None, **kwargs):
+        class _Res:
+            generations = [[type("G", (), {"text": "", "generation_info": {}})]]
+            llm_output = {}
+
+        return _Res()
+
+    async def _agenerate(self, prompts, stop=None, run_manager=None, **kwargs):
+        return self._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)
+
 
 def test_import_class() -> None:
     """Test that the class can be imported."""
@@ -10,3 +46,38 @@ def test_import_class() -> None:
 
     module = import_module(module_name)
     assert hasattr(module, class_name)
+
+
+def test_generate_passes_tools_to_tokenizer() -> None:
+    llm = _FakeLLM()
+    chat = ChatMLX(llm=llm)
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "foo",
+                "description": "",
+                "parameters": {"type": "object", "properties": {}},
+            },
+        }
+    ]
+    chat._generate([HumanMessage(content="hi")], tools=tools)
+    assert llm.tokenizer.tools == tools
+
+
+@pytest.mark.asyncio
+async def test_agenerate_passes_tools_to_tokenizer() -> None:
+    llm = _FakeLLM()
+    chat = ChatMLX(llm=llm)
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "foo",
+                "description": "",
+                "parameters": {"type": "object", "properties": {}},
+            },
+        }
+    ]
+    await chat._agenerate([HumanMessage(content="hi")], tools=tools)
+    assert llm.tokenizer.tools == tools
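
Note that the second test is decorated with @pytest.mark.asyncio, so it only runs when the pytest-asyncio plugin is installed; plain pytest skips async test functions with a warning. A small sketch for running just the two new tests follows; the -k pattern is an assumption that matches both test names.

# Sketch: run only the new tool-passing tests. Assumes pytest and
# pytest-asyncio are installed.
import pytest

pytest.main(
    [
        "libs/community/tests/unit_tests/chat_models/test_mlx.py",
        "-k", "passes_tools",  # matches both new tests
        "-q",
    ]
)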
