Skip to content

Commit fa68005

Browse files
fixing more tests
1 parent 503876a commit fa68005

File tree

5 files changed

+618
-354
lines changed

5 files changed

+618
-354
lines changed

stagehand/handlers/extract_handler.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from stagehand.a11y.utils import get_accessibility_tree
88
from stagehand.llm.inference import extract as extract_inference
99
from stagehand.metrics import StagehandFunctionName # Changed import location
10-
from stagehand.types import DefaultExtractSchema, ExtractOptions, ExtractResult
10+
from stagehand.schemas import DEFAULT_EXTRACT_SCHEMA as DefaultExtractSchema, ExtractOptions, ExtractResult
1111
from stagehand.utils import inject_urls, transform_url_strings_to_ids
1212

1313
T = TypeVar("T", bound=BaseModel)
@@ -153,10 +153,12 @@ async def extract(
153153
f"Failed to validate extracted data against schema {schema.__name__}: {e}. Keeping raw data dict in .data field."
154154
)
155155

156-
# Create ExtractResult object
157-
result = ExtractResult(
158-
data=processed_data_payload,
159-
)
156+
# Create ExtractResult object with extracted data as fields
157+
if isinstance(processed_data_payload, dict):
158+
result = ExtractResult(**processed_data_payload)
159+
else:
160+
# For non-dict data (like Pydantic models), create with data field
161+
result = ExtractResult(data=processed_data_payload)
160162

161163
return result
162164

tests/conftest.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,65 @@ def mock_stagehand_page(mock_playwright_page):
8585
mock_client.logger.error = MagicMock()
8686
mock_client._get_lock_for_session = MagicMock(return_value=AsyncMock())
8787
mock_client._execute = AsyncMock()
88+
mock_client.update_metrics = MagicMock()
8889

8990
stagehand_page = StagehandPage(mock_playwright_page, mock_client)
91+
92+
# Mock CDP calls for accessibility tree
93+
async def mock_send_cdp(method, params=None):
    """Return a canned CDP reply for the given protocol method name.

    Args:
        method: CDP method name, e.g. ``"Accessibility.getFullAXTree"``.
        params: Accepted for signature compatibility; it is not consulted.

    Returns:
        A freshly built dict for the three recognised methods, or an empty
        dict for any other method name.
    """

    def _ax_node(backend_id, role, label):
        # One accessibility-tree entry; nodeId is the string form of the
        # backend DOM node id in both fixture nodes.
        return {
            "nodeId": str(backend_id),
            "role": {"value": role},
            "name": {"value": label},
            "backendDOMNodeId": backend_id,
            "childIds": [],
            "properties": [],
        }

    if method == "Accessibility.getFullAXTree":
        # Two leaf nodes: a button and a textbox.
        fixture_nodes = [
            _ax_node(1, "button", "Click me"),
            _ax_node(2, "textbox", "Search input"),
        ]
        return {"nodes": fixture_nodes}

    if method == "DOM.resolveNode":
        resolved_object = {"objectId": "test-object-id"}
        return {"object": resolved_object}

    if method == "Runtime.callFunctionOn":
        call_result = {"value": "//div[@id='test']"}
        return {"result": call_result}

    # Any other method gets an empty payload.
    return {}
128+
129+
stagehand_page.send_cdp = AsyncMock(side_effect=mock_send_cdp)
130+
131+
# Mock get_cdp_client to return a mock CDP session
132+
mock_cdp_client = AsyncMock()
133+
mock_cdp_client.send = AsyncMock(return_value={"result": {"value": "//div[@id='test']"}})
134+
stagehand_page.get_cdp_client = AsyncMock(return_value=mock_cdp_client)
135+
136+
# Mock ensure_injection and evaluate methods
137+
stagehand_page.ensure_injection = AsyncMock()
138+
stagehand_page.evaluate = AsyncMock(return_value=[])
139+
140+
# Mock enable/disable CDP domain methods
141+
stagehand_page.enable_cdp_domain = AsyncMock()
142+
stagehand_page.disable_cdp_domain = AsyncMock()
143+
144+
# Mock _wait_for_settled_dom to avoid asyncio.sleep issues
145+
stagehand_page._wait_for_settled_dom = AsyncMock()
146+
90147
return stagehand_page
91148

92149

tests/mocks/mock_llm.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ def _create_response(self, data: Any, model: str) -> MockLLMResponse:
144144
if isinstance(data, str):
145145
return MockLLMResponse(data, model=model)
146146
elif isinstance(data, dict):
147-
content = data.get("content", str(data))
147+
# For extract responses, convert dict to JSON string for content
148+
import json
149+
content = json.dumps(data)
148150
return MockLLMResponse(content, data=data, model=model)
149151
else:
150152
return MockLLMResponse(str(data), data=data, model=model)
@@ -247,4 +249,60 @@ def get_usage_stats(self) -> Dict[str, int]:
247249
"total_prompt_tokens": total_prompt_tokens,
248250
"total_completion_tokens": total_completion_tokens,
249251
"total_tokens": total_prompt_tokens + total_completion_tokens
250-
}
252+
}
253+
254+
def create_response(
    self,
    *,
    messages: list[dict[str, str]],
    model: Optional[str] = None,
    function_name: Optional[str] = None,
    **kwargs
) -> "MockLLMResponse":
    """Create a response using the same interface as the real LLMClient.

    Args:
        messages: Chat messages for this call. They are recorded and, when
            ``function_name`` is not given, stringified and lower-cased to
            detect the response type.
        model: Model name to record; falls back to ``self.default_model``.
        function_name: When truthy, its lower-cased form is used directly
            as the response type.
        **kwargs: Recorded with the call and forwarded to callable
            custom responses and to response generators.

    Returns:
        The object built by ``self._create_response`` for the selected
        response data.

    Raises:
        Exception: With ``self.failure_message`` when ``self.should_fail``
            is set. The call is still counted and recorded first.
    """
    # Local import keeps this block self-contained. This method is
    # synchronous, so it may run with no event loop; asking asyncio for a
    # loop there is deprecated and fails on newer Python versions.
    import time

    # Use function_name to determine response type if available,
    # otherwise fall back to content-based detection.
    if function_name:
        response_type = function_name.lower()
    else:
        content = str(messages).lower()
        response_type = self._determine_response_type(content)

    # Track the call
    self.call_count += 1
    self.last_messages = messages
    self.last_model = model or self.default_model
    self.last_kwargs = kwargs

    # Store call in history. time.monotonic() is the clock the default
    # event loop's time() reads, so the timestamps stay comparable.
    call_info = {
        "messages": messages,
        "model": self.last_model,
        "kwargs": kwargs,
        "function_name": function_name,
        "timestamp": time.monotonic(),
    }
    self.call_history.append(call_info)

    # Simulate failure if configured
    if self.should_fail:
        raise Exception(self.failure_message)

    # Check for custom responses first
    if response_type in self.custom_responses:
        response_data = self.custom_responses[response_type]
        if callable(response_data):
            response_data = response_data(messages, **kwargs)
        # NOTE(review): this early return skips the metrics callback below;
        # confirm that custom responses are meant to go unreported.
        return self._create_response(response_data, model=self.last_model)

    # Use default response mapping
    response_generator = self.response_mapping.get(response_type, self._default_response)
    response_data = response_generator(messages, **kwargs)

    response = self._create_response(response_data, model=self.last_model)

    # Call metrics callback if set
    if self.metrics_callback:
        self.metrics_callback(response, 100, response_type)  # 100ms mock inference time

    return response

0 commit comments

Comments
 (0)