Skip to content

Commit c210e07

Browse files
diego-coderRN
authored andcommitted
Refine MLX ReAct tool-call parsing
1 parent b2d3001 commit c210e07

File tree

1 file changed

+90
-1
lines changed
  • libs/community/langchain_community/chat_models

1 file changed

+90
-1
lines changed

libs/community/langchain_community/chat_models/mlx.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
import json
44
import logging
5+
import re
6+
import uuid
7+
58
from typing import (
69
Any,
710
Callable,
@@ -11,6 +14,7 @@
1114
Literal,
1215
Optional,
1316
Sequence,
17+
Tuple,
1418
Type,
1519
Union,
1620
)
@@ -26,7 +30,13 @@
2630
AIMessageChunk,
2731
BaseMessage,
2832
HumanMessage,
33+
InvalidToolCall,
2934
SystemMessage,
35+
ToolCall,
36+
)
37+
from langchain_core.output_parsers.openai_tools import (
38+
make_invalid_tool_call,
39+
parse_tool_call,
3040
)
3141
from langchain_core.outputs import (
3242
ChatGeneration,
@@ -45,6 +55,52 @@
4555
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
4656

4757

58+
def _parse_react_tool_calls(
59+
text: str,
60+
) -> Tuple[list[ToolCall] | None, list[InvalidToolCall]]:
61+
"""Extract ReAct-style tool calls from plain text output.
62+
63+
Args:
64+
text: Raw model generation text.
65+
66+
Returns:
67+
A tuple containing a list of parsed ``ToolCall`` objects if any were
68+
detected, otherwise ``None``, and a list of ``InvalidToolCall`` objects
69+
for unparseable patterns.
70+
"""
71+
72+
tool_calls: list[ToolCall] = []
73+
invalid_tool_calls: list[InvalidToolCall] = []
74+
75+
bracket_pattern = r"Action:\s*(?P<name>[\w.-]+)\[(?P<input>[^\]]+)\]"
76+
separate_pattern = (
77+
r"Action:\s*(?P<name>[^\n]+)\nAction Input:\s*(?P<input>[^\n]+)"
78+
)
79+
80+
matches = list(re.finditer(bracket_pattern, text))
81+
if not matches:
82+
matches = list(re.finditer(separate_pattern, text))
83+
84+
for match in matches:
85+
name = match.group("name").strip()
86+
arg_text = match.group("input").strip()
87+
try:
88+
args = json.loads(arg_text)
89+
if not isinstance(args, dict):
90+
args = {"input": args}
91+
except Exception:
92+
args = {"input": arg_text}
93+
tool_calls.append(ToolCall(id=str(uuid.uuid4()), name=name, args=args))
94+
95+
if not tool_calls and "Action:" in text:
96+
invalid_tool_calls.append(
97+
make_invalid_tool_call(text, "Could not parse ReAct tool call")
98+
)
99+
return None, invalid_tool_calls
100+
101+
return tool_calls or None, invalid_tool_calls
102+
103+
48104
class ChatMLX(BaseChatModel):
49105
"""MLX chat models.
50106
@@ -170,8 +226,41 @@ def _to_chat_result(llm_result: LLMResult) -> ChatResult:
170226
chat_generations = []
171227

172228
for g in llm_result.generations[0]:
229+
tool_calls: list[ToolCall] = []
230+
invalid_tool_calls: list[InvalidToolCall] = []
231+
additional_kwargs: Dict[str, Any] = {}
232+
233+
if isinstance(g.generation_info, dict):
234+
raw_tool_calls = g.generation_info.get("tool_calls")
235+
else:
236+
raw_tool_calls = None
237+
238+
if raw_tool_calls:
239+
additional_kwargs["tool_calls"] = raw_tool_calls
240+
for raw_tool_call in raw_tool_calls:
241+
try:
242+
tc = parse_tool_call(raw_tool_call, return_id=True)
243+
except Exception as e:
244+
invalid_tool_calls.append(
245+
make_invalid_tool_call(raw_tool_call, str(e))
246+
)
247+
else:
248+
if tc:
249+
tool_calls.append(tc)
250+
else:
251+
react_tool_calls, invalid_reacts = _parse_react_tool_calls(g.text)
252+
if react_tool_calls is not None:
253+
tool_calls.extend(react_tool_calls)
254+
invalid_tool_calls.extend(invalid_reacts)
255+
173256
chat_generation = ChatGeneration(
174-
message=AIMessage(content=g.text), generation_info=g.generation_info
257+
message=AIMessage(
258+
content=g.text,
259+
additional_kwargs=additional_kwargs,
260+
tool_calls=tool_calls,
261+
invalid_tool_calls=invalid_tool_calls,
262+
),
263+
generation_info=g.generation_info,
175264
)
176265
chat_generations.append(chat_generation)
177266

0 commit comments

Comments
 (0)