diff --git a/examples/quickstart.py b/examples/quickstart.py index 20daf858..b0b2ba06 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -44,7 +44,7 @@ async def main(): # Extract companies using structured schema companies_data = await page.extract( "Extract names and descriptions of 5 companies in batch 3", - schema=Companies + schema_definition=Companies ) # Display results @@ -66,4 +66,4 @@ async def main(): await stagehand.close() if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/stagehand/schemas.py b/stagehand/schemas.py index 5ff23fb2..62f98d2f 100644 --- a/stagehand/schemas.py +++ b/stagehand/schemas.py @@ -158,6 +158,14 @@ def _resolve_references(self, obj: Any, definitions: dict, ref_prefix: str) -> N model_config = ConfigDict(arbitrary_types_allowed=True) +class AttrDict(dict): + """A dictionary that allows attribute-style access to its items.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + class ExtractResult(StagehandBaseModel): """ Result of the 'extract' command. @@ -171,6 +179,24 @@ class ExtractResult(StagehandBaseModel): model_config = ConfigDict(extra="allow") # Allow any extra fields + def __init__(self, **data): + """Initialize and recursively convert nested dictionaries to AttrDict objects.""" + # Convert nested dictionaries to AttrDict for attribute access + converted_data = self._convert_to_attr_dict(data) + super().__init__(**converted_data) + + def _convert_to_attr_dict(self, obj): + """Recursively convert dictionaries to AttrDict objects.""" + if isinstance(obj, dict): + # Convert dict to AttrDict and recursively convert nested objects + attr_dict = AttrDict() + for key, value in obj.items(): + attr_dict[key] = self._convert_to_attr_dict(value) + return attr_dict + elif isinstance(obj, list): + return [self._convert_to_attr_dict(item) for item in obj] + return obj + def __getitem__(self, key): """ Enable dictionary-style access to attributes.