Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stagehand/handlers/extract_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def extract(
# TODO: Remove this once we have a better way to handle URLs
transformed_schema, url_paths = transform_url_strings_to_ids(schema)
else:
transformed_schema = DefaultExtractSchema
schema = transformed_schema = DefaultExtractSchema

# Use inference to call the LLM
extraction_response = extract_inference(
Expand Down
88 changes: 46 additions & 42 deletions stagehand/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ObserveOptions,
ObserveResult,
)
from .types import DefaultExtractSchema

_INJECTION_SCRIPT = None

Expand Down Expand Up @@ -236,15 +237,36 @@ async def extract(
# Otherwise use API implementation
# Allow for no options to extract the entire page
if options is None:
options_obj = ExtractOptions()
payload = {}
# Convert string to ExtractOptions if needed
elif isinstance(options, str):
options = ExtractOptions(instruction=options)
payload = options.model_dump(exclude_none=True, by_alias=True)
options_obj = ExtractOptions(instruction=options)
payload = options_obj.model_dump(exclude_none=True, by_alias=True)
# Otherwise, it should be an ExtractOptions object
else:
options_obj = options
# Allow extraction without instruction if other options (like schema) are provided
payload = options.model_dump(exclude_none=True, by_alias=True)
payload = options_obj.model_dump(exclude_none=True, by_alias=True)

# Determine the schema to pass to the handler
schema_to_validate_with = None
if (
hasattr(options_obj, "schema_definition")
and options_obj.schema_definition != DEFAULT_EXTRACT_SCHEMA
):
if isinstance(options_obj.schema_definition, type) and issubclass(
options_obj.schema_definition, BaseModel
):
# Case 1: Pydantic model class
schema_to_validate_with = options_obj.schema_definition
elif isinstance(options_obj.schema_definition, dict):
# TODO: revisit this case to pass the json_schema since litellm has a bug when passing it directly
# Case 2: Dictionary
# Assume it's a direct JSON schema dictionary
schema_to_validate_with = options_obj.schema_definition
else:
schema_to_validate_with = DefaultExtractSchema

# If in LOCAL mode, use local implementation
if self._stagehand.env == "LOCAL":
Expand All @@ -263,57 +285,39 @@ async def extract(
)
return result

# Convert string to ExtractOptions if needed
if isinstance(options, str):
options = ExtractOptions(instruction=options)

# Determine the schema to pass to the handler
schema_to_pass_to_handler = None
if (
hasattr(options, "schema_definition")
and options.schema_definition != DEFAULT_EXTRACT_SCHEMA
):
if isinstance(options.schema_definition, type) and issubclass(
options.schema_definition, BaseModel
):
# Case 1: Pydantic model class
schema_to_pass_to_handler = options.schema_definition
elif isinstance(options.schema_definition, dict):
# TODO: revisit this case to pass the json_schema since litellm has a bug when passing it directly
# Case 2: Dictionary
# Assume it's a direct JSON schema dictionary
schema_to_pass_to_handler = options.schema_definition

# Call local extract implementation
result = await self._extract_handler.extract(
options,
schema_to_pass_to_handler,
options_obj,
schema_to_validate_with,
)
return result.data

# Use API
lock = self._stagehand._get_lock_for_session()
async with lock:
result = await self._stagehand._execute("extract", payload)
result_dict = await self._stagehand._execute("extract", payload)

# Attempt to parse the result using the base ExtractResult,
# which allows extra fields based on the dynamic schema.
if isinstance(result, dict):
if isinstance(result_dict, dict):
# Pydantic will validate against known fields + allow extras if configured
try:
# Note: We don't know the exact return structure here,
# ExtractResult allows extra fields.
# The user needs to access data based on their schema.
return ExtractResult(**result)
except Exception as e:
self._stagehand.logger.error(f"Failed to parse extract result: {e}")
# Return raw dict if parsing fails, or raise? Returning dict for now.
return result # type: ignore
processed_data_payload = result_dict
if schema_to_validate_with and isinstance(processed_data_payload, dict):
try:
validated_model = schema_to_validate_with.model_validate(
processed_data_payload
)
processed_data_payload = (
validated_model # Payload is now the Pydantic model instance
)
except Exception as e:
self._stagehand.logger.error(
f"Failed to validate extracted data against schema {schema_to_validate_with.__name__}: {e}. Keeping raw data dict in .data field."
)
return ExtractResult(data=processed_data_payload).data
# Handle unexpected return types
self._stagehand.logger.info(
f"Unexpected result type from extract: {type(result)}"
f"Unexpected result type from extract: {type(result_dict)}"
)
# Return raw result if not dict or raise error
return result # type: ignore
return result_dict

async def screenshot(self, options: Optional[dict] = None) -> str:
"""
Expand Down
Loading