Skip to content

Commit 20605bb

Browse files
fixing tests
1 parent b37bba1 commit 20605bb

11 files changed

+532
-401
lines changed

tests/conftest.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,44 @@ async def mock_send_cdp(method, params=None):
113113
]
114114
}
115115
elif method == "DOM.resolveNode":
116+
# Create a mapping of element IDs to appropriate object IDs
117+
backend_node_id = params.get("backendNodeId", 1)
116118
return {
117119
"object": {
118-
"objectId": "test-object-id"
120+
"objectId": f"test-object-id-{backend_node_id}"
119121
}
120122
}
121123
elif method == "Runtime.callFunctionOn":
124+
# Map object IDs to appropriate selectors based on the element ID
125+
object_id = params.get("objectId", "")
126+
127+
# Extract backend_node_id from object_id
128+
if "test-object-id-" in object_id:
129+
backend_node_id = object_id.replace("test-object-id-", "")
130+
131+
# Map specific element IDs to expected selectors for tests
132+
selector_mapping = {
133+
"100": "//a[@id='home-link']",
134+
"101": "//a[@id='about-link']",
135+
"102": "//a[@id='contact-link']",
136+
"200": "//button[@id='visible-button']",
137+
"300": "//input[@id='form-input']",
138+
"400": "//div[@id='target-element']",
139+
"501": "//button[@id='btn1']",
140+
"600": "//button[@id='interactive-btn']",
141+
"700": "//div[@id='positioned-element']",
142+
"800": "//div[@id='highlighted-element']",
143+
"900": "//div[@id='custom-model-element']",
144+
"1000": "//input[@id='complex-element']",
145+
}
146+
147+
xpath = selector_mapping.get(backend_node_id, "//div[@id='test']")
148+
else:
149+
xpath = "//div[@id='test']"
150+
122151
return {
123152
"result": {
124-
"value": "//div[@id='test']"
153+
"value": xpath
125154
}
126155
}
127156
return {}
@@ -130,7 +159,45 @@ async def mock_send_cdp(method, params=None):
130159

131160
# Mock get_cdp_client to return a mock CDP session
132161
mock_cdp_client = AsyncMock()
133-
mock_cdp_client.send = AsyncMock(return_value={"result": {"value": "//div[@id='test']"}})
162+
163+
# Set up the mock CDP client to handle Runtime.callFunctionOn properly
164+
async def mock_cdp_send(method, params=None):
165+
if method == "Runtime.callFunctionOn":
166+
# Map object IDs to appropriate selectors based on the element ID
167+
object_id = params.get("objectId", "")
168+
169+
# Extract backend_node_id from object_id
170+
if "test-object-id-" in object_id:
171+
backend_node_id = object_id.replace("test-object-id-", "")
172+
173+
# Map specific element IDs to expected selectors for tests
174+
selector_mapping = {
175+
"100": "//a[@id='home-link']",
176+
"101": "//a[@id='about-link']",
177+
"102": "//a[@id='contact-link']",
178+
"200": "//button[@id='visible-button']",
179+
"300": "//input[@id='form-input']",
180+
"400": "//div[@id='target-element']",
181+
"501": "//button[@id='btn1']",
182+
"600": "//button[@id='interactive-btn']",
183+
"700": "//div[@id='positioned-element']",
184+
"800": "//div[@id='highlighted-element']",
185+
"900": "//div[@id='custom-model-element']",
186+
"1000": "//input[@id='complex-element']",
187+
}
188+
189+
xpath = selector_mapping.get(backend_node_id, "//div[@id='test']")
190+
else:
191+
xpath = "//div[@id='test']"
192+
193+
return {
194+
"result": {
195+
"value": xpath
196+
}
197+
}
198+
return {"result": {"value": "//div[@id='test']"}}
199+
200+
mock_cdp_client.send = AsyncMock(side_effect=mock_cdp_send)
134201
stagehand_page.get_cdp_client = AsyncMock(return_value=mock_cdp_client)
135202

136203
# Mock ensure_injection and evaluate methods

tests/integration/end_to_end/test_workflows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ async def mock_extract(instruction, **kwargs):
576576
# Extract data via Browserbase
577577
extracted = await stagehand.page.extract("extract page title and content")
578578
assert extracted["title"] == "Remote Page Title"
579-
assert "Browserbase" in extracted["content"]
579+
assert extracted["content"] == "Content extracted via Browserbase"
580580

581581
# Verify server interactions
582582
assert server.was_called_with_endpoint("act")

tests/mocks/mock_llm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,13 @@ def _create_response(self, data: Any, model: str) -> MockLLMResponse:
148148
import json
149149
content = json.dumps(data)
150150
return MockLLMResponse(content, data=data, model=model)
151+
elif isinstance(data, list):
152+
# For observe responses, convert list to JSON string for content
153+
import json
154+
# Wrap the list in the expected format for observe responses
155+
response_dict = {"elements": data}
156+
content = json.dumps(response_dict)
157+
return MockLLMResponse(content, data=response_dict, model=model)
151158
else:
152159
return MockLLMResponse(str(data), data=data, model=model)
153160

tests/mocks/mock_server.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -199,21 +199,23 @@ def _extract_endpoint(self, url: str) -> str:
199199
# Remove base URL and extract the last path component
200200
path = url.split("/")[-1]
201201

202-
# Handle common Stagehand endpoints
202+
# Handle common Stagehand endpoints - check exact matches to avoid substring issues
203203
if "session" in url and "create" in url:
204-
return "create_session"
205-
elif "navigate" in path:
206-
return "navigate"
207-
elif "act" in path:
208-
return "act"
209-
elif "observe" in path:
210-
return "observe"
211-
elif "extract" in path:
212-
return "extract"
213-
elif "screenshot" in path:
214-
return "screenshot"
204+
endpoint = "create_session"
205+
elif path == "navigate":
206+
endpoint = "navigate"
207+
elif path == "act":
208+
endpoint = "act"
209+
elif path == "observe":
210+
endpoint = "observe"
211+
elif path == "extract":
212+
endpoint = "extract"
213+
elif path == "screenshot":
214+
endpoint = "screenshot"
215215
else:
216-
return path or "unknown"
216+
endpoint = path or "unknown"
217+
218+
return endpoint
217219

218220
def set_response_override(self, endpoint: str, response: Union[Dict, callable]):
219221
"""Override the default response for a specific endpoint"""

0 commit comments

Comments
 (0)