Skip to content

Commit 14f85b2

Browse files
authored
fix result payload parsing and dump on pydantic (#90)
1 parent c2c2674 commit 14f85b2

File tree

2 files changed: +47 -43 lines changed

2 files changed: +47 -43 lines changed

stagehand/handlers/extract_handler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@ async def extract(
9393
# TODO: Remove this once we have a better way to handle URLs
9494
transformed_schema, url_paths = transform_url_strings_to_ids(schema)
9595
else:
96-
transformed_schema = DefaultExtractSchema
96+
schema = transformed_schema = DefaultExtractSchema
9797

9898
# Use inference to call the LLM
9999
extraction_response = extract_inference(

stagehand/page.py

Lines changed: 46 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
ObserveOptions,
1717
ObserveResult,
1818
)
19+
from .types import DefaultExtractSchema
1920

2021
_INJECTION_SCRIPT = None
2122

@@ -236,15 +237,36 @@ async def extract(
236237
# Otherwise use API implementation
237238
# Allow for no options to extract the entire page
238239
if options is None:
240+
options_obj = ExtractOptions()
239241
payload = {}
240242
# Convert string to ExtractOptions if needed
241243
elif isinstance(options, str):
242-
options = ExtractOptions(instruction=options)
243-
payload = options.model_dump(exclude_none=True, by_alias=True)
244+
options_obj = ExtractOptions(instruction=options)
245+
payload = options_obj.model_dump(exclude_none=True, by_alias=True)
244246
# Otherwise, it should be an ExtractOptions object
245247
else:
248+
options_obj = options
246249
# Allow extraction without instruction if other options (like schema) are provided
247-
payload = options.model_dump(exclude_none=True, by_alias=True)
250+
payload = options_obj.model_dump(exclude_none=True, by_alias=True)
251+
252+
# Determine the schema to pass to the handler
253+
schema_to_validate_with = None
254+
if (
255+
hasattr(options_obj, "schema_definition")
256+
and options_obj.schema_definition != DEFAULT_EXTRACT_SCHEMA
257+
):
258+
if isinstance(options_obj.schema_definition, type) and issubclass(
259+
options_obj.schema_definition, BaseModel
260+
):
261+
# Case 1: Pydantic model class
262+
schema_to_validate_with = options_obj.schema_definition
263+
elif isinstance(options_obj.schema_definition, dict):
264+
# TODO: revisit this case to pass the json_schema since litellm has a bug when passing it directly
265+
# Case 2: Dictionary
266+
# Assume it's a direct JSON schema dictionary
267+
schema_to_validate_with = options_obj.schema_definition
268+
else:
269+
schema_to_validate_with = DefaultExtractSchema
248270

249271
# If in LOCAL mode, use local implementation
250272
if self._stagehand.env == "LOCAL":
@@ -263,57 +285,39 @@ async def extract(
263285
)
264286
return result
265287

266-
# Convert string to ExtractOptions if needed
267-
if isinstance(options, str):
268-
options = ExtractOptions(instruction=options)
269-
270-
# Determine the schema to pass to the handler
271-
schema_to_pass_to_handler = None
272-
if (
273-
hasattr(options, "schema_definition")
274-
and options.schema_definition != DEFAULT_EXTRACT_SCHEMA
275-
):
276-
if isinstance(options.schema_definition, type) and issubclass(
277-
options.schema_definition, BaseModel
278-
):
279-
# Case 1: Pydantic model class
280-
schema_to_pass_to_handler = options.schema_definition
281-
elif isinstance(options.schema_definition, dict):
282-
# TODO: revisit this case to pass the json_schema since litellm has a bug when passing it directly
283-
# Case 2: Dictionary
284-
# Assume it's a direct JSON schema dictionary
285-
schema_to_pass_to_handler = options.schema_definition
286-
287288
# Call local extract implementation
288289
result = await self._extract_handler.extract(
289-
options,
290-
schema_to_pass_to_handler,
290+
options_obj,
291+
schema_to_validate_with,
291292
)
292293
return result.data
293294

295+
# Use API
294296
lock = self._stagehand._get_lock_for_session()
295297
async with lock:
296-
result = await self._stagehand._execute("extract", payload)
298+
result_dict = await self._stagehand._execute("extract", payload)
297299

298-
# Attempt to parse the result using the base ExtractResult,
299-
# which allows extra fields based on the dynamic schema.
300-
if isinstance(result, dict):
300+
if isinstance(result_dict, dict):
301301
# Pydantic will validate against known fields + allow extras if configured
302-
try:
303-
# Note: We don't know the exact return structure here,
304-
# ExtractResult allows extra fields.
305-
# The user needs to access data based on their schema.
306-
return ExtractResult(**result)
307-
except Exception as e:
308-
self._stagehand.logger.error(f"Failed to parse extract result: {e}")
309-
# Return raw dict if parsing fails, or raise? Returning dict for now.
310-
return result # type: ignore
302+
processed_data_payload = result_dict
303+
if schema_to_validate_with and isinstance(processed_data_payload, dict):
304+
try:
305+
validated_model = schema_to_validate_with.model_validate(
306+
processed_data_payload
307+
)
308+
processed_data_payload = (
309+
validated_model # Payload is now the Pydantic model instance
310+
)
311+
except Exception as e:
312+
self._stagehand.logger.error(
313+
f"Failed to validate extracted data against schema {schema_to_validate_with.__name__}: {e}. Keeping raw data dict in .data field."
314+
)
315+
return ExtractResult(data=processed_data_payload).data
311316
# Handle unexpected return types
312317
self._stagehand.logger.info(
313-
f"Unexpected result type from extract: {type(result)}"
318+
f"Unexpected result type from extract: {type(result_dict)}"
314319
)
315-
# Return raw result if not dict or raise error
316-
return result # type: ignore
320+
return result_dict
317321

318322
async def screenshot(self, options: Optional[dict] = None) -> str:
319323
"""

0 commit comments

Comments
 (0)