|
2 | 2 |
|
3 | 3 | import json |
4 | 4 | import logging |
| 5 | +import re |
| 6 | +import uuid |
| 7 | + |
5 | 8 | from typing import ( |
6 | 9 | Any, |
7 | 10 | Callable, |
|
11 | 14 | Literal, |
12 | 15 | Optional, |
13 | 16 | Sequence, |
| 17 | + Tuple, |
14 | 18 | Type, |
15 | 19 | Union, |
16 | 20 | ) |
|
26 | 30 | AIMessageChunk, |
27 | 31 | BaseMessage, |
28 | 32 | HumanMessage, |
| 33 | + InvalidToolCall, |
29 | 34 | SystemMessage, |
| 35 | + ToolCall, |
| 36 | +) |
| 37 | +from langchain_core.output_parsers.openai_tools import ( |
| 38 | + make_invalid_tool_call, |
| 39 | + parse_tool_call, |
30 | 40 | ) |
31 | 41 | from langchain_core.outputs import ( |
32 | 42 | ChatGeneration, |
|
45 | 55 | DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant.""" |
46 | 56 |
|
47 | 57 |
|
| 58 | +def _parse_react_tool_calls( |
| 59 | + text: str, |
| 60 | +) -> Tuple[list[ToolCall] | None, list[InvalidToolCall]]: |
| 61 | + """Extract ReAct-style tool calls from plain text output. |
| 62 | +
|
| 63 | + Args: |
| 64 | + text: Raw model generation text. |
| 65 | +
|
| 66 | + Returns: |
| 67 | + A tuple containing a list of parsed ``ToolCall`` objects if any were |
| 68 | + detected, otherwise ``None``, and a list of ``InvalidToolCall`` objects |
| 69 | + for unparseable patterns. |
| 70 | + """ |
| 71 | + |
| 72 | + tool_calls: list[ToolCall] = [] |
| 73 | + invalid_tool_calls: list[InvalidToolCall] = [] |
| 74 | + |
| 75 | + bracket_pattern = r"Action:\s*(?P<name>[\w.-]+)\[(?P<input>[^\]]+)\]" |
| 76 | + separate_pattern = ( |
| 77 | + r"Action:\s*(?P<name>[^\n]+)\nAction Input:\s*(?P<input>[^\n]+)" |
| 78 | + ) |
| 79 | + |
| 80 | + matches = list(re.finditer(bracket_pattern, text)) |
| 81 | + if not matches: |
| 82 | + matches = list(re.finditer(separate_pattern, text)) |
| 83 | + |
| 84 | + for match in matches: |
| 85 | + name = match.group("name").strip() |
| 86 | + arg_text = match.group("input").strip() |
| 87 | + try: |
| 88 | + args = json.loads(arg_text) |
| 89 | + if not isinstance(args, dict): |
| 90 | + args = {"input": args} |
| 91 | + except Exception: |
| 92 | + args = {"input": arg_text} |
| 93 | + tool_calls.append(ToolCall(id=str(uuid.uuid4()), name=name, args=args)) |
| 94 | + |
| 95 | + if not tool_calls and "Action:" in text: |
| 96 | + invalid_tool_calls.append( |
| 97 | + make_invalid_tool_call(text, "Could not parse ReAct tool call") |
| 98 | + ) |
| 99 | + return None, invalid_tool_calls |
| 100 | + |
| 101 | + return tool_calls or None, invalid_tool_calls |
| 102 | + |
| 103 | + |
48 | 104 | class ChatMLX(BaseChatModel): |
49 | 105 | """MLX chat models. |
50 | 106 |
|
@@ -170,8 +226,41 @@ def _to_chat_result(llm_result: LLMResult) -> ChatResult: |
170 | 226 | chat_generations = [] |
171 | 227 |
|
172 | 228 | for g in llm_result.generations[0]: |
| 229 | + tool_calls: list[ToolCall] = [] |
| 230 | + invalid_tool_calls: list[InvalidToolCall] = [] |
| 231 | + additional_kwargs: Dict[str, Any] = {} |
| 232 | + |
| 233 | + if isinstance(g.generation_info, dict): |
| 234 | + raw_tool_calls = g.generation_info.get("tool_calls") |
| 235 | + else: |
| 236 | + raw_tool_calls = None |
| 237 | + |
| 238 | + if raw_tool_calls: |
| 239 | + additional_kwargs["tool_calls"] = raw_tool_calls |
| 240 | + for raw_tool_call in raw_tool_calls: |
| 241 | + try: |
| 242 | + tc = parse_tool_call(raw_tool_call, return_id=True) |
| 243 | + except Exception as e: |
| 244 | + invalid_tool_calls.append( |
| 245 | + make_invalid_tool_call(raw_tool_call, str(e)) |
| 246 | + ) |
| 247 | + else: |
| 248 | + if tc: |
| 249 | + tool_calls.append(tc) |
| 250 | + else: |
| 251 | + react_tool_calls, invalid_reacts = _parse_react_tool_calls(g.text) |
| 252 | + if react_tool_calls is not None: |
| 253 | + tool_calls.extend(react_tool_calls) |
| 254 | + invalid_tool_calls.extend(invalid_reacts) |
| 255 | + |
173 | 256 | chat_generation = ChatGeneration( |
174 | | - message=AIMessage(content=g.text), generation_info=g.generation_info |
| 257 | + message=AIMessage( |
| 258 | + content=g.text, |
| 259 | + additional_kwargs=additional_kwargs, |
| 260 | + tool_calls=tool_calls, |
| 261 | + invalid_tool_calls=invalid_tool_calls, |
| 262 | + ), |
| 263 | + generation_info=g.generation_info, |
175 | 264 | ) |
176 | 265 | chat_generations.append(chat_generation) |
177 | 266 |
|
|
0 commit comments