diff --git a/stagehand/handlers/extract_handler.py b/stagehand/handlers/extract_handler.py index 9025ff82..21af7a99 100644 --- a/stagehand/handlers/extract_handler.py +++ b/stagehand/handlers/extract_handler.py @@ -93,7 +93,7 @@ async def extract( # TODO: Remove this once we have a better way to handle URLs transformed_schema, url_paths = transform_url_strings_to_ids(schema) else: - transformed_schema = DefaultExtractSchema + schema = transformed_schema = DefaultExtractSchema # Use inference to call the LLM extraction_response = extract_inference( diff --git a/stagehand/page.py b/stagehand/page.py index d01c83af..05f20056 100644 --- a/stagehand/page.py +++ b/stagehand/page.py @@ -16,6 +16,7 @@ ObserveOptions, ObserveResult, ) +from .types import DefaultExtractSchema _INJECTION_SCRIPT = None @@ -236,15 +237,36 @@ async def extract( # Otherwise use API implementation # Allow for no options to extract the entire page if options is None: + options_obj = ExtractOptions() payload = {} # Convert string to ExtractOptions if needed elif isinstance(options, str): - options = ExtractOptions(instruction=options) - payload = options.model_dump(exclude_none=True, by_alias=True) + options_obj = ExtractOptions(instruction=options) + payload = options_obj.model_dump(exclude_none=True, by_alias=True) # Otherwise, it should be an ExtractOptions object else: + options_obj = options # Allow extraction without instruction if other options (like schema) are provided - payload = options.model_dump(exclude_none=True, by_alias=True) + payload = options_obj.model_dump(exclude_none=True, by_alias=True) + + # Determine the schema to pass to the handler + schema_to_validate_with = None + if ( + hasattr(options_obj, "schema_definition") + and options_obj.schema_definition != DEFAULT_EXTRACT_SCHEMA + ): + if isinstance(options_obj.schema_definition, type) and issubclass( + options_obj.schema_definition, BaseModel + ): + # Case 1: Pydantic model class + schema_to_validate_with = options_obj.schema_definition + elif isinstance(options_obj.schema_definition, dict): + # TODO: revisit this case to pass the json_schema since litellm has a bug when passing it directly + # Case 2: Dictionary + # Assume it's a direct JSON schema dictionary + schema_to_validate_with = options_obj.schema_definition + else: + schema_to_validate_with = DefaultExtractSchema # If in LOCAL mode, use local implementation if self._stagehand.env == "LOCAL": @@ -263,57 +285,39 @@ async def extract( ) return result - # Convert string to ExtractOptions if needed - if isinstance(options, str): - options = ExtractOptions(instruction=options) - - # Determine the schema to pass to the handler - schema_to_pass_to_handler = None - if ( - hasattr(options, "schema_definition") - and options.schema_definition != DEFAULT_EXTRACT_SCHEMA - ): - if isinstance(options.schema_definition, type) and issubclass( - options.schema_definition, BaseModel - ): - # Case 1: Pydantic model class - schema_to_pass_to_handler = options.schema_definition - elif isinstance(options.schema_definition, dict): - # TODO: revisit this case to pass the json_schema since litellm has a bug when passing it directly - # Case 2: Dictionary - # Assume it's a direct JSON schema dictionary - schema_to_pass_to_handler = options.schema_definition - # Call local extract implementation result = await self._extract_handler.extract( - options, - schema_to_pass_to_handler, + options_obj, + schema_to_validate_with, ) return result.data + # Use API lock = self._stagehand._get_lock_for_session() async with lock: - result = await self._stagehand._execute("extract", payload) + result_dict = await self._stagehand._execute("extract", payload) - # Attempt to parse the result using the base ExtractResult, - # which allows extra fields based on the dynamic schema. - if isinstance(result, dict): + if isinstance(result_dict, dict): # Pydantic will validate against known fields + allow extras if configured - try: - # Note: We don't know the exact return structure here, - # ExtractResult allows extra fields. - # The user needs to access data based on their schema. - return ExtractResult(**result) - except Exception as e: - self._stagehand.logger.error(f"Failed to parse extract result: {e}") - # Return raw dict if parsing fails, or raise? Returning dict for now. - return result # type: ignore + processed_data_payload = result_dict + if schema_to_validate_with and isinstance(processed_data_payload, dict): + try: + validated_model = schema_to_validate_with.model_validate( + processed_data_payload + ) + processed_data_payload = ( + validated_model # Payload is now the Pydantic model instance + ) + except Exception as e: + self._stagehand.logger.error( + f"Failed to validate extracted data against schema {schema_to_validate_with.__name__}: {e}. Keeping raw data dict in .data field." + ) + return ExtractResult(data=processed_data_payload).data # Handle unexpected return types self._stagehand.logger.info( - f"Unexpected result type from extract: {type(result)}" + f"Unexpected result type from extract: {type(result_dict)}" ) - # Return raw result if not dict or raise error - return result # type: ignore + return result_dict async def screenshot(self, options: Optional[dict] = None) -> str: """