|
1 |
| -from typing import Union |
| 1 | +import re |
| 2 | +from typing import Literal, Optional, Union |
2 | 3 |
|
3 | 4 | from langchain_core.agents import AgentAction, AgentFinish
|
| 5 | +from pydantic import Field |
4 | 6 |
|
5 | 7 | from langchain.agents import AgentOutputParser
|
6 | 8 |
|
7 | 9 |
|
| 10 | +def _unescape(text: str) -> str: |
| 11 | + """Convert custom tag delimiters back into XML tags.""" |
| 12 | + replacements = { |
| 13 | + "[[tool]]": "<tool>", |
| 14 | + "[[/tool]]": "</tool>", |
| 15 | + "[[tool_input]]": "<tool_input>", |
| 16 | + "[[/tool_input]]": "</tool_input>", |
| 17 | + "[[observation]]": "<observation>", |
| 18 | + "[[/observation]]": "</observation>", |
| 19 | + } |
| 20 | + for repl, orig in replacements.items(): |
| 21 | + text = text.replace(repl, orig) |
| 22 | + return text |
| 23 | + |
| 24 | + |
8 | 25 | class XMLAgentOutputParser(AgentOutputParser):
|
9 |
| - """Parses tool invocations and final answers in XML format. |
| 26 | + """Parses tool invocations and final answers from XML-formatted agent output. |
| 27 | +
|
| 28 | + This parser extracts structured information from XML tags to determine whether |
| 29 | + an agent should perform a tool action or provide a final answer. It includes |
| 30 | + built-in escaping support to safely handle tool names and inputs |
| 31 | + containing XML special characters. |
| 32 | +
|
| 33 | + Args: |
| 34 | + escape_format: The escaping format to use when parsing XML content. |
| 35 | + Supports 'minimal' which uses custom delimiters like [[tool]] to replace |
| 36 | + XML tags within content, preventing parsing conflicts. |
| 37 | + Use 'minimal' if using a corresponding encoding format that uses |
| 38 | + the _escape function when formatting the output (e.g., with format_xml). |
| 39 | +
|
| 40 | + Expected formats: |
| 41 | + Tool invocation (returns AgentAction): |
| 42 | + <tool>search</tool> |
| 43 | + <tool_input>what is 2 + 2</tool_input> |
10 | 44 |
|
11 |
| - Expects output to be in one of two formats. |
| 45 | + Final answer (returns AgentFinish): |
| 46 | + <final_answer>The answer is 4</final_answer> |
12 | 47 |
|
13 |
| - If the output signals that an action should be taken, |
14 |
| - should be in the below format. This will result in an AgentAction |
15 |
| - being returned. |
| 48 | + Note: |
| 49 | + Minimal escaping allows tool names containing XML tags to be safely |
| 50 | + represented. For example, a tool named "search<tool>nested</tool>" would be |
| 51 | + escaped as "search[[tool]]nested[[/tool]]" in the XML and automatically |
| 52 | + unescaped during parsing. |
16 | 53 |
|
17 |
| - ``` |
18 |
| - <tool>search</tool> |
19 |
| - <tool_input>what is 2 + 2</tool_input> |
20 |
| - ``` |
| 54 | + Raises: |
| 55 | + ValueError: If the input doesn't match either expected XML format or |
| 56 | + contains malformed XML structure. |
| 57 | + """ |
| 58 | + |
| 59 | + escape_format: Optional[Literal["minimal"]] = Field(default="minimal") |
| 60 | + """The format to use for escaping XML characters. |
21 | 61 |
|
22 |
| - If the output signals that a final answer should be given, |
23 |
| - should be in the below format. This will result in an AgentFinish |
24 |
| - being returned. |
| 62 | + minimal - uses custom delimiters to replace XML tags within content, |
| 63 | + preventing parsing conflicts. This is the only supported format currently. |
25 | 64 |
|
26 |
| - ``` |
27 |
| - <final_answer>Foo</final_answer> |
28 |
| - ``` |
| 65 | + None - no escaping is applied, which may lead to parsing conflicts. |
29 | 66 | """
|
30 | 67 |
|
31 | 68 | def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
32 |
| - if "</tool>" in text: |
33 |
| - tool, tool_input = text.split("</tool>") |
34 |
| - _tool = tool.split("<tool>")[1] |
35 |
| - _tool_input = tool_input.split("<tool_input>")[1] |
36 |
| - if "</tool_input>" in _tool_input: |
37 |
| - _tool_input = _tool_input.split("</tool_input>")[0] |
| 69 | + # Check for tool invocation first |
| 70 | + tool_matches = re.findall(r"<tool>(.*?)</tool>", text, re.DOTALL) |
| 71 | + if tool_matches: |
| 72 | + if len(tool_matches) != 1: |
| 73 | + msg = ( |
| 74 | + f"Malformed tool invocation: expected exactly one <tool> block, " |
| 75 | + f"but found {len(tool_matches)}." |
| 76 | + ) |
| 77 | + raise ValueError(msg) |
| 78 | + _tool = tool_matches[0] |
| 79 | + |
| 80 | + # Match optional tool input |
| 81 | + input_matches = re.findall( |
| 82 | + r"<tool_input>(.*?)</tool_input>", text, re.DOTALL |
| 83 | + ) |
| 84 | + if len(input_matches) > 1: |
| 85 | + msg = ( |
| 86 | + f"Malformed tool invocation: expected at most one <tool_input> " |
| 87 | + f"block, but found {len(input_matches)}." |
| 88 | + ) |
| 89 | + raise ValueError(msg) |
| 90 | + _tool_input = input_matches[0] if input_matches else "" |
| 91 | + |
| 92 | + # Unescape if minimal escape format is used |
| 93 | + if self.escape_format == "minimal": |
| 94 | + _tool = _unescape(_tool) |
| 95 | + _tool_input = _unescape(_tool_input) |
| 96 | + |
38 | 97 | return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
|
39 |
| - if "<final_answer>" in text: |
40 |
| - _, answer = text.split("<final_answer>") |
41 |
| - if "</final_answer>" in answer: |
42 |
| - answer = answer.split("</final_answer>")[0] |
| 98 | + # Check for final answer |
| 99 | + if "<final_answer>" in text and "</final_answer>" in text: |
| 100 | + matches = re.findall(r"<final_answer>(.*?)</final_answer>", text, re.DOTALL) |
| 101 | + if len(matches) != 1: |
| 102 | + msg = ( |
| 103 | + "Malformed output: expected exactly one " |
| 104 | + "<final_answer>...</final_answer> block." |
| 105 | + ) |
| 106 | + raise ValueError(msg) |
| 107 | + answer = matches[0] |
| 108 | + # Unescape custom delimiters in final answer |
| 109 | + if self.escape_format == "minimal": |
| 110 | + answer = _unescape(answer) |
43 | 111 | return AgentFinish(return_values={"output": answer}, log=text)
|
44 |
| - raise ValueError |
| 112 | + msg = ( |
| 113 | + "Malformed output: expected either a tool invocation " |
| 114 | + "or a final answer in XML format." |
| 115 | + ) |
| 116 | + raise ValueError(msg) |
45 | 117 |
|
46 | 118 | def get_format_instructions(self) -> str:
|
47 | 119 | raise NotImplementedError
|
|
0 commit comments