Skip to content

Commit f570ed5

Browse files
Fix XPath injection in search_item_ctrl_f (#1768)
1 parent c9992da commit f570ed5

File tree

2 files changed

+152
-1
lines changed

2 files changed

+152
-1
lines changed

src/smolagents/vision_web_browser.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,24 @@ def save_screenshot(memory_step: ActionStep, agent: CodeAgent) -> None:
8484
return
8585

8686

87+
def _escape_xpath_string(s: str) -> str:
88+
"""
89+
Escapes a string for safe use in an XPath expression.
90+
91+
Args:
92+
s (`str`): Arbitrary input string to escape.
93+
94+
Returns:
95+
`str`: Valid XPath expression representing the literal value of `s`.
96+
"""
97+
if "'" not in s:
98+
return f"'{s}'"
99+
if '"' not in s:
100+
return f'"{s}"'
101+
parts = s.split("'")
102+
return "concat(" + ', "\'", '.join(f"'{p}'" for p in parts) + ")"
103+
104+
87105
@tool
88106
def search_item_ctrl_f(text: str, nth_result: int = 1) -> str:
89107
"""
@@ -92,7 +110,8 @@ def search_item_ctrl_f(text: str, nth_result: int = 1) -> str:
92110
text: The text to search for
93111
nth_result: Which occurrence to jump to (default: 1)
94112
"""
95-
elements = driver.find_elements(By.XPATH, f"//*[contains(text(), '{text}')]")
113+
escaped_text = _escape_xpath_string(text)
114+
elements = driver.find_elements(By.XPATH, f"//*[contains(text(), {escaped_text})]")
96115
if nth_result > len(elements):
97116
raise Exception(f"Match n°{nth_result} not found (only {len(elements)} matches found)")
98117
result = f"Found {len(elements)} matches for '{text}'."

tests/test_vision_web_browser.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""Test XPath injection vulnerability fix in vision_web_browser.py"""
2+
3+
from unittest.mock import Mock, patch
4+
5+
import pytest
6+
7+
from smolagents.vision_web_browser import _escape_xpath_string, search_item_ctrl_f
8+
9+
10+
@pytest.fixture
def mock_driver():
    """Stand-in for the Selenium WebDriver: one matching element, scripts return nothing."""
    fake_driver = Mock()
    fake_driver.configure_mock(
        **{
            "find_elements.return_value": [Mock()],  # a single found element
            "execute_script.return_value": None,
        }
    )
    return fake_driver
17+
18+
19+
class TestXPathEscaping:
    """Test XPath string escaping functionality"""

    @staticmethod
    def _decode_xpath_string_expr(expr: str) -> str:
        """
        Recover the string value denoted by an escaped XPath string expression.

        Accepts either one quoted literal or a `concat(...)` call whose arguments are all
        quoted literals. Any text outside those literals (other than the `concat(`, `, `
        and `)` scaffolding) means part of the input broke out of its quotes, so the
        helper fails the calling test.

        Args:
            expr (`str`): Expression produced by `_escape_xpath_string`.

        Returns:
            `str`: The concatenated contents of the literals found in `expr`.
        """
        import re  # local import keeps this helper self-contained within the test class

        literal_pattern = re.compile(r"'[^']*'|\"[^\"]*\"")
        literals = literal_pattern.findall(expr)
        scaffolding = literal_pattern.sub("", expr)

        if len(literals) == 1:
            assert scaffolding == "", f"Text found outside the quoted literal: {expr}"
        else:
            # What remains of concat('a', "'", 'b') once the literals are removed.
            expected_scaffolding = "concat(" + ", ".join([""] * len(literals)) + ")"
            assert len(literals) >= 2, f"No quoted literal found: {expr}"
            assert scaffolding == expected_scaffolding, f"Text found outside the quoted literals: {expr}"

        return "".join(literal[1:-1] for literal in literals)

    @pytest.mark.parametrize(
        "input_text,expected_pattern",
        [
            ("normal text", "'normal text'"),
            ("text with 'quote'", "\"text with 'quote'\""),
            ('text with "quote"', "'text with \"quote\"'"),
            ("text with one single'quote", '"text with one single\'quote"'),
            ('text with one double"quote', "'text with one double\"quote'"),
            (
                "text with both 'single' and \"double\" quotes",
                "concat('text with both ', \"'\", 'single', \"'\", ' and \"double\" quotes')",
            ),
            ("", "''"),
            ("'", '"\'"'),
            ('"', "'\"'"),
        ],
    )
    def test_escape_xpath_string_basic(self, input_text, expected_pattern):
        """Test basic XPath escaping cases"""
        result = _escape_xpath_string(input_text)
        assert result == expected_pattern
        # The expected pattern itself must denote the original text.
        assert self._decode_xpath_string_expr(result) == input_text

    @pytest.mark.parametrize(
        "input_text",
        [
            "text with both 'single' and \"double\" quotes",
            'it\'s a "test" case',
            "'mixed\" quotes'",
        ],
    )
    def test_escape_xpath_string_mixed_quotes(self, input_text):
        """Test XPath escaping with mixed quotes uses concat() and preserves the text"""
        result = _escape_xpath_string(input_text)
        assert result.startswith("concat(")
        assert result.endswith(")")
        # Checking only the concat( ... ) frame would accept malformed arguments.
        assert self._decode_xpath_string_expr(result) == input_text

    @pytest.mark.parametrize(
        "malicious_input",
        [
            "')] | //script[@src='evil.js'] | foo[contains(text(), '",
            "') or 1=1 or ('",
            "')] | //user[contains(@role,'admin')] | foo[contains(text(), '",
            "') and substring(//user[1]/password,1,1)='a",
        ],
    )
    def test_escape_prevents_injection(self, malicious_input):
        """Test that malicious XPath injection attempts are safely escaped"""
        result = _escape_xpath_string(malicious_input)
        # Inspecting only the first and last characters would also accept a naive
        # f"'{s}'" wrapping, which is the vulnerable form. Instead require that the
        # whole payload stays inside quoted literals and decodes to the raw input.
        assert self._decode_xpath_string_expr(result) == malicious_input
76+
77+
78+
class TestSearchItemCtrlF:
    """Test the search_item_ctrl_f function with XPath injection protection"""

    @pytest.mark.parametrize(
        "search_text",
        [
            "normal search",
            "search with 'quotes'",
            'search with "quotes"',
            "')] | //script[@src='evil.js'] | foo[contains(text(), '",
            "') or 1=1 or ('",
        ],
    )
    def test_search_item_prevents_injection(self, search_text, mock_driver):
        """Test that search_item_ctrl_f prevents XPath injection"""
        with patch("smolagents.vision_web_browser.driver", mock_driver, create=True):
            result = search_item_ctrl_f(search_text)

        # Exactly one lookup must have been issued.
        mock_driver.find_elements.assert_called_once()

        # Second positional argument is the generated XPath query.
        xpath_query = mock_driver.find_elements.call_args[0][1]

        # Pin the full query for every input. Counting quote characters would also
        # accept the unescaped query, so it cannot detect a regression.
        expected_query = f"//*[contains(text(), {_escape_xpath_string(search_text)})]"
        assert xpath_query == expected_query, f"XPath injection not prevented: {xpath_query}"

        # Independent of the helper: text containing a single quote must never be
        # interpolated verbatim between single quotes.
        if "'" in search_text:
            assert f"'{search_text}'" not in xpath_query, f"XPath injection not prevented: {xpath_query}"

        # Verify we got a result
        assert "Found" in result

    def test_search_item_nth_result(self, mock_driver):
        """Test nth_result parameter works correctly"""
        mock_driver.find_elements.return_value = [Mock(), Mock(), Mock()]  # 3 elements

        with patch("smolagents.vision_web_browser.driver", mock_driver, create=True):
            result = search_item_ctrl_f("test", nth_result=2)

        # Should find 3 matches and focus on element 2
        assert "Found 3 matches" in result
        assert "Focused on element 2 of 3" in result

    def test_search_item_not_found(self, mock_driver):
        """Test exception when nth_result exceeds available matches"""
        mock_driver.find_elements.return_value = [Mock()]  # Only 1 element

        with patch("smolagents.vision_web_browser.driver", mock_driver, create=True):
            with pytest.raises(Exception, match="Match n°3 not found"):
                search_item_ctrl_f("test", nth_result=3)

0 commit comments

Comments
 (0)