|
7 | 7 | from stagehand.a11y.utils import get_accessibility_tree
|
8 | 8 | from stagehand.llm.inference import extract as extract_inference
|
9 | 9 | from stagehand.metrics import StagehandFunctionName # Changed import location
|
10 |
| -from stagehand.schemas import DEFAULT_EXTRACT_SCHEMA as DefaultExtractSchema, ExtractOptions, ExtractResult |
| 10 | +from stagehand.schemas import ( |
| 11 | + DEFAULT_EXTRACT_SCHEMA, |
| 12 | + ExtractOptions, |
| 13 | + ExtractResult, |
| 14 | +) |
11 | 15 | from stagehand.utils import inject_urls, transform_url_strings_to_ids
|
12 | 16 |
|
13 | 17 | T = TypeVar("T", bound=BaseModel)
|
@@ -93,7 +97,7 @@ async def extract(
|
93 | 97 | # TODO: Remove this once we have a better way to handle URLs
|
94 | 98 | transformed_schema, url_paths = transform_url_strings_to_ids(schema)
|
95 | 99 | else:
|
96 |
| - transformed_schema = DefaultExtractSchema |
| 100 | + transformed_schema = DEFAULT_EXTRACT_SCHEMA |
97 | 101 |
|
98 | 102 | # Use inference to call the LLM
|
99 | 103 | extraction_response = extract_inference(
|
@@ -149,15 +153,15 @@ async def extract(
|
149 | 153 | validated_model_instance = schema.model_validate(raw_data_dict)
|
150 | 154 | processed_data_payload = validated_model_instance # Payload is now the Pydantic model instance
|
151 | 155 | except Exception as e:
|
152 |
| - schema_name = getattr(schema, '__name__', str(schema)) |
| 156 | + schema_name = getattr(schema, "__name__", str(schema)) |
153 | 157 | self.logger.error(
|
154 | 158 | f"Failed to validate extracted data against schema {schema_name}: {e}. Keeping raw data dict in .data field."
|
155 | 159 | )
|
156 | 160 |
|
157 | 161 | # Create ExtractResult object with extracted data as fields
|
158 | 162 | if isinstance(processed_data_payload, dict):
|
159 | 163 | result = ExtractResult(**processed_data_payload)
|
160 |
| - elif hasattr(processed_data_payload, 'model_dump'): |
| 164 | + elif hasattr(processed_data_payload, "model_dump"): |
161 | 165 | # For Pydantic models, convert to dict and spread as fields
|
162 | 166 | result = ExtractResult(**processed_data_payload.model_dump())
|
163 | 167 | else:
|
|
0 commit comments