Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/daffy-rapid-turaco.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"stagehand": patch
---

Fix parsing schema for extract with no arguments (full page extract)
17 changes: 17 additions & 0 deletions format
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
#
# Format and lint the project's Python sources.
#
# Steps, in order:
#   1. Black reformats the source directories in place.
#   2. Ruff applies every autofix it can (including import sorting).
#   3. Ruff runs once more, without --fix, to report whatever is left.
#
# Exit status: 0 when formatting succeeded and no lint issues remain,
# 1 otherwise, so callers (CI, pre-commit hooks) can detect failure.

# Define source directories (adjust as needed). Kept as an array so that
# each entry is passed as exactly one argument, even if it contains spaces.
SOURCE_DIRS=("stagehand")

# Accumulates failures so every step still runs and reports its output.
status=0

# Apply Black formatting first
echo "Applying Black formatting..."
black "${SOURCE_DIRS[@]}" || status=1

# Apply Ruff with autofix for all issues (including import sorting).
# A non-zero exit here only means unfixable issues remain; the final
# check below is what decides the overall result.
echo "Applying Ruff autofixes (including import sorting)..."
ruff check --fix "${SOURCE_DIRS[@]}" || true

echo "Checking for remaining issues..."
ruff check "${SOURCE_DIRS[@]}" || status=1

if [ "$status" -eq 0 ]; then
    echo "Done! Code has been formatted and linted."
else
    # Do not claim success when a step failed or issues are still present.
    echo "Finished with errors: formatting failed or lint issues remain." >&2
fi

exit "$status"
11 changes: 9 additions & 2 deletions stagehand/handlers/extract_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from stagehand.a11y.utils import get_accessibility_tree
from stagehand.llm.inference import extract as extract_inference
from stagehand.metrics import StagehandFunctionName # Changed import location
from stagehand.types import DefaultExtractSchema, ExtractOptions, ExtractResult
from stagehand.types import (
DefaultExtractSchema,
EmptyExtractSchema,
ExtractOptions,
ExtractResult,
)
from stagehand.utils import inject_urls, transform_url_strings_to_ids

T = TypeVar("T", bound=BaseModel)
Expand Down Expand Up @@ -166,4 +171,6 @@ async def _extract_page_text(self) -> ExtractResult:

tree = await get_accessibility_tree(self.stagehand_page, self.logger)
output_string = tree["simplified"]
return ExtractResult(data=output_string)
output_dict = {"page_text": output_string}
validated_model = EmptyExtractSchema.model_validate(output_dict)
return ExtractResult(data=validated_model).data
19 changes: 12 additions & 7 deletions stagehand/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ObserveOptions,
ObserveResult,
)
from .types import DefaultExtractSchema
from .types import DefaultExtractSchema, EmptyExtractSchema

_INJECTION_SCRIPT = None

Expand Down Expand Up @@ -361,12 +361,17 @@ async def extract(
processed_data_payload = result_dict
if schema_to_validate_with and isinstance(processed_data_payload, dict):
try:
validated_model = schema_to_validate_with.model_validate(
processed_data_payload
)
processed_data_payload = (
validated_model # Payload is now the Pydantic model instance
)
# For extract with no params
if not options_obj:
validated_model = EmptyExtractSchema.model_validate(
processed_data_payload
)
processed_data_payload = validated_model
else:
validated_model = schema_to_validate_with.model_validate(
processed_data_payload
)
processed_data_payload = validated_model
except Exception as e:
self._stagehand.logger.error(
f"Failed to validate extracted data against schema {schema_to_validate_with.__name__}: {e}. Keeping raw data dict in .data field."
Expand Down
2 changes: 2 additions & 0 deletions stagehand/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ActOptions,
ActResult,
DefaultExtractSchema,
EmptyExtractSchema,
ExtractOptions,
ExtractResult,
MetadataSchema,
Expand Down Expand Up @@ -56,4 +57,5 @@
"AgentConfig",
"AgentExecuteOptions",
"AgentResult",
"EmptyExtractSchema",
]
4 changes: 4 additions & 0 deletions stagehand/types/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ class DefaultExtractSchema(BaseModel):
extraction: str


# Schema used to validate the result of an extract call made with no
# arguments (full page extract).
class EmptyExtractSchema(BaseModel):
# Text of the page; required string field.
page_text: str


class ObserveElementSchema(BaseModel):
element_id: int
description: str = Field(
Expand Down