Skip to content

Commit 83d8be7

Browse files
eyurtsev and mdrxy authored
langchain[patch]: harden xml parser for xmloutput agent (#31859)
Harden the default implementation of the XML parser for the agent.

Co-authored-by: Mason Daugherty <[email protected]>
Co-authored-by: Mason Daugherty <[email protected]>
1 parent 3f839d5 commit 83d8be7

File tree

4 files changed: +205 -28 lines changed
libs/langchain/langchain/agents/format_scratchpad/xml.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,52 @@
1+
from typing import Literal, Optional
2+
13
from langchain_core.agents import AgentAction
24

35

6+
def _escape(xml: str) -> str:
7+
"""Replace XML tags with custom safe delimiters."""
8+
replacements = {
9+
"<tool>": "[[tool]]",
10+
"</tool>": "[[/tool]]",
11+
"<tool_input>": "[[tool_input]]",
12+
"</tool_input>": "[[/tool_input]]",
13+
"<observation>": "[[observation]]",
14+
"</observation>": "[[/observation]]",
15+
}
16+
for orig, repl in replacements.items():
17+
xml = xml.replace(orig, repl)
18+
return xml
19+
20+
421
def format_xml(
522
intermediate_steps: list[tuple[AgentAction, str]],
23+
*,
24+
escape_format: Optional[Literal["minimal"]] = "minimal",
625
) -> str:
726
"""Format the intermediate steps as XML.
827
928
Args:
1029
intermediate_steps: The intermediate steps.
30+
escape_format: The escaping format to use. Currently only 'minimal' is
31+
supported, which replaces XML tags with custom delimiters to prevent
32+
conflicts.
1133
1234
Returns:
1335
The intermediate steps as XML.
1436
"""
1537
log = ""
1638
for action, observation in intermediate_steps:
39+
if escape_format == "minimal":
40+
# Escape XML tags in tool names and inputs using custom delimiters
41+
tool = _escape(action.tool)
42+
tool_input = _escape(str(action.tool_input))
43+
observation = _escape(str(observation))
44+
else:
45+
tool = action.tool
46+
tool_input = str(action.tool_input)
47+
observation = str(observation)
1748
log += (
18-
f"<tool>{action.tool}</tool><tool_input>{action.tool_input}"
49+
f"<tool>{tool}</tool><tool_input>{tool_input}"
1950
f"</tool_input><observation>{observation}</observation>"
2051
)
2152
return log

libs/langchain/langchain/agents/output_parsers/xml.py

Lines changed: 99 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,119 @@
1-
from typing import Union
1+
import re
2+
from typing import Literal, Optional, Union
23

34
from langchain_core.agents import AgentAction, AgentFinish
5+
from pydantic import Field
46

57
from langchain.agents import AgentOutputParser
68

79

10+
def _unescape(text: str) -> str:
11+
"""Convert custom tag delimiters back into XML tags."""
12+
replacements = {
13+
"[[tool]]": "<tool>",
14+
"[[/tool]]": "</tool>",
15+
"[[tool_input]]": "<tool_input>",
16+
"[[/tool_input]]": "</tool_input>",
17+
"[[observation]]": "<observation>",
18+
"[[/observation]]": "</observation>",
19+
}
20+
for repl, orig in replacements.items():
21+
text = text.replace(repl, orig)
22+
return text
23+
24+
825
class XMLAgentOutputParser(AgentOutputParser):
9-
"""Parses tool invocations and final answers in XML format.
26+
"""Parses tool invocations and final answers from XML-formatted agent output.
27+
28+
This parser extracts structured information from XML tags to determine whether
29+
an agent should perform a tool action or provide a final answer. It includes
30+
built-in escaping support to safely handle tool names and inputs
31+
containing XML special characters.
32+
33+
Args:
34+
escape_format: The escaping format to use when parsing XML content.
35+
Supports 'minimal' which uses custom delimiters like [[tool]] to replace
36+
XML tags within content, preventing parsing conflicts.
37+
Use 'minimal' if using a corresponding encoding format that uses
38+
the _escape function when formatting the output (e.g., with format_xml).
39+
40+
Expected formats:
41+
Tool invocation (returns AgentAction):
42+
<tool>search</tool>
43+
<tool_input>what is 2 + 2</tool_input>
1044
11-
Expects output to be in one of two formats.
45+
Final answer (returns AgentFinish):
46+
<final_answer>The answer is 4</final_answer>
1247
13-
If the output signals that an action should be taken,
14-
should be in the below format. This will result in an AgentAction
15-
being returned.
48+
Note:
49+
Minimal escaping allows tool names containing XML tags to be safely
50+
represented. For example, a tool named "search<tool>nested</tool>" would be
51+
escaped as "search[[tool]]nested[[/tool]]" in the XML and automatically
52+
unescaped during parsing.
1653
17-
```
18-
<tool>search</tool>
19-
<tool_input>what is 2 + 2</tool_input>
20-
```
54+
Raises:
55+
ValueError: If the input doesn't match either expected XML format or
56+
contains malformed XML structure.
57+
"""
58+
59+
escape_format: Optional[Literal["minimal"]] = Field(default="minimal")
60+
"""The format to use for escaping XML characters.
2161
22-
If the output signals that a final answer should be given,
23-
should be in the below format. This will result in an AgentFinish
24-
being returned.
62+
minimal - uses custom delimiters to replace XML tags within content,
63+
preventing parsing conflicts. This is the only supported format currently.
2564
26-
```
27-
<final_answer>Foo</final_answer>
28-
```
65+
None - no escaping is applied, which may lead to parsing conflicts.
2966
"""
3067

3168
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
32-
if "</tool>" in text:
33-
tool, tool_input = text.split("</tool>")
34-
_tool = tool.split("<tool>")[1]
35-
_tool_input = tool_input.split("<tool_input>")[1]
36-
if "</tool_input>" in _tool_input:
37-
_tool_input = _tool_input.split("</tool_input>")[0]
69+
# Check for tool invocation first
70+
tool_matches = re.findall(r"<tool>(.*?)</tool>", text, re.DOTALL)
71+
if tool_matches:
72+
if len(tool_matches) != 1:
73+
msg = (
74+
f"Malformed tool invocation: expected exactly one <tool> block, "
75+
f"but found {len(tool_matches)}."
76+
)
77+
raise ValueError(msg)
78+
_tool = tool_matches[0]
79+
80+
# Match optional tool input
81+
input_matches = re.findall(
82+
r"<tool_input>(.*?)</tool_input>", text, re.DOTALL
83+
)
84+
if len(input_matches) > 1:
85+
msg = (
86+
f"Malformed tool invocation: expected at most one <tool_input> "
87+
f"block, but found {len(input_matches)}."
88+
)
89+
raise ValueError(msg)
90+
_tool_input = input_matches[0] if input_matches else ""
91+
92+
# Unescape if minimal escape format is used
93+
if self.escape_format == "minimal":
94+
_tool = _unescape(_tool)
95+
_tool_input = _unescape(_tool_input)
96+
3897
return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
39-
if "<final_answer>" in text:
40-
_, answer = text.split("<final_answer>")
41-
if "</final_answer>" in answer:
42-
answer = answer.split("</final_answer>")[0]
98+
# Check for final answer
99+
if "<final_answer>" in text and "</final_answer>" in text:
100+
matches = re.findall(r"<final_answer>(.*?)</final_answer>", text, re.DOTALL)
101+
if len(matches) != 1:
102+
msg = (
103+
"Malformed output: expected exactly one "
104+
"<final_answer>...</final_answer> block."
105+
)
106+
raise ValueError(msg)
107+
answer = matches[0]
108+
# Unescape custom delimiters in final answer
109+
if self.escape_format == "minimal":
110+
answer = _unescape(answer)
43111
return AgentFinish(return_values={"output": answer}, log=text)
44-
raise ValueError
112+
msg = (
113+
"Malformed output: expected either a tool invocation "
114+
"or a final answer in XML format."
115+
)
116+
raise ValueError(msg)
45117

46118
def get_format_instructions(self) -> str:
47119
raise NotImplementedError

libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,42 @@ def test_multiple_agent_actions_observations() -> None:
3939
def test_empty_list_agent_actions() -> None:
4040
result = format_xml([])
4141
assert result == ""
42+
43+
44+
def test_xml_escaping_minimal() -> None:
45+
"""Test that XML tags in tool names are escaped with minimal format."""
46+
# Arrange
47+
agent_action = AgentAction(
48+
tool="search<tool>nested</tool>", tool_input="query<input>test</input>", log=""
49+
)
50+
observation = "Found <observation>result</observation>"
51+
intermediate_steps = [(agent_action, observation)]
52+
53+
# Act
54+
result = format_xml(intermediate_steps, escape_format="minimal")
55+
56+
# Assert - XML tags should be replaced with custom delimiters
57+
expected_result = (
58+
"<tool>search[[tool]]nested[[/tool]]</tool>"
59+
"<tool_input>query<input>test</input></tool_input>"
60+
"<observation>Found [[observation]]result[[/observation]]</observation>"
61+
)
62+
assert result == expected_result
63+
64+
65+
def test_no_escaping() -> None:
66+
"""Test that escaping can be disabled."""
67+
# Arrange
68+
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="")
69+
observation = "Observation1"
70+
intermediate_steps = [(agent_action, observation)]
71+
72+
# Act
73+
result = format_xml(intermediate_steps, escape_format=None)
74+
75+
# Assert
76+
expected_result = (
77+
"<tool>Tool1</tool><tool_input>Input1</tool_input>"
78+
"<observation>Observation1</observation>"
79+
)
80+
assert result == expected_result

libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,38 @@ def test_finish() -> None:
3232
output = parser.invoke(_input)
3333
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
3434
assert output == expected_output
35+
36+
37+
def test_malformed_xml_with_nested_tags() -> None:
38+
"""Test handling of tool names with XML tags via format_xml minimal escaping."""
39+
from langchain.agents.format_scratchpad.xml import format_xml
40+
41+
# Create an AgentAction with XML tags in the tool name
42+
action = AgentAction(tool="search<tool>nested</tool>", tool_input="query", log="")
43+
44+
# The format_xml function should escape the XML tags using custom delimiters
45+
formatted_xml = format_xml([(action, "observation")])
46+
47+
# Extract just the tool part for parsing
48+
tool_part = formatted_xml.split("<observation>")[0] # Remove observation part
49+
50+
# Now test that the parser can handle the escaped XML
51+
parser = XMLAgentOutputParser(escape_format="minimal")
52+
output = parser.invoke(tool_part)
53+
54+
# The parser should unescape and extract the original tool name
55+
expected_output = AgentAction(
56+
tool="search<tool>nested</tool>", tool_input="query", log=tool_part
57+
)
58+
assert output == expected_output
59+
60+
61+
def test_no_escaping() -> None:
62+
"""Test parser with escaping disabled."""
63+
parser = XMLAgentOutputParser(escape_format=None)
64+
65+
# Test with regular tool name (no XML tags)
66+
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
67+
output = parser.invoke(_input)
68+
expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
69+
assert output == expected_output

0 commit comments

Comments
 (0)