diff --git a/stagehand/page.py b/stagehand/page.py index 05f20056..ecaf1126 100644 --- a/stagehand/page.py +++ b/stagehand/page.py @@ -134,8 +134,12 @@ async def act( elif isinstance(action_or_result, str): options = ActOptions(action=action_or_result, **kwargs) payload = options.model_dump(exclude_none=True, by_alias=True) + elif isinstance(action_or_result, ActOptions): + payload = action_or_result.model_dump(exclude_none=True, by_alias=True) else: - payload = options.model_dump(exclude_none=True, by_alias=True) + raise TypeError( + "Invalid arguments for 'act'. Expected str, ObserveResult, or ActOptions." + ) # TODO: Temporary until we move api based logic to client if self._stagehand.env == "LOCAL": @@ -158,12 +162,19 @@ async def act( return ActResult(**result) return result - async def observe(self, options: Union[str, ObserveOptions]) -> list[ObserveResult]: + async def observe( + self, + options_or_instruction: Union[str, ObserveOptions, None] = None, + **kwargs, + ) -> list[ObserveResult]: """ Make an AI observation via the Stagehand server. Args: - instruction (str): The observation instruction for the AI. + options_or_instruction (Union[str, ObserveOptions, None]): + - A string with the observation instruction for the AI. + - An ObserveOptions object. + - None to use default options. **kwargs: Additional options corresponding to fields in ObserveOptions (e.g., model_name, only_visible, return_action). @@ -172,15 +183,29 @@ async def observe(self, options: Union[str, ObserveOptions]) -> list[ObserveResu """ await self.ensure_injection() - # Convert string to ObserveOptions if needed - if isinstance(options, str): - options = ObserveOptions(instruction=options) - # Handle None by creating an empty options object - elif options is None: - options = ObserveOptions() + options_dict = {} + + if isinstance(options_or_instruction, ObserveOptions): + # Already a pydantic object – take it as is. + options_obj = options_or_instruction + else: + if isinstance(options_or_instruction, str): + options_dict["instruction"] = options_or_instruction + + # Merge any explicit keyword arguments (highest priority) + options_dict.update(kwargs) + + if not options_dict: + raise TypeError("No instruction provided for observe.") + + try: + options_obj = ObserveOptions(**options_dict) + except Exception as e: + raise TypeError(f"Invalid observe options: {e}") from e + + # Serialized payload for server / local handlers + payload = options_obj.model_dump(exclude_none=True, by_alias=True) - # Otherwise use API implementation - payload = options.model_dump(exclude_none=True, by_alias=True) # If in LOCAL mode, use local implementation if self._stagehand.env == "LOCAL": self._stagehand.logger.debug( @@ -193,7 +218,7 @@ async def observe(self, options: Union[str, ObserveOptions]) -> list[ObserveResu # Call local observe implementation result = await self._observe_handler.observe( - options, + options_obj, from_act=False, ) @@ -216,43 +241,74 @@ async def observe(self, options: Union[str, ObserveOptions]) -> list[ObserveResu return [] async def extract( - self, options: Union[str, ExtractOptions, None] = None + self, + options_or_instruction: Union[str, ExtractOptions, None] = None, + *, + schema: Optional[type[BaseModel]] = None, + **kwargs, ) -> ExtractResult: - # TODO update args """ Extract data using AI via the Stagehand server. Args: - instruction (Optional[str]): Instruction specifying what data to extract. - If None, attempts to extract the entire page content - based on other kwargs (e.g., schema_definition). + options_or_instruction (Union[str, ExtractOptions, None]): + - A string with the instruction specifying what data to extract. + - An ExtractOptions object. + - None to extract the entire page content. + schema (Optional[Union[type[BaseModel], None]]): + A Pydantic model class that defines the structure + of the expected extracted data. **kwargs: Additional options corresponding to fields in ExtractOptions - (e.g., schema_definition, model_name, use_text_extract). + (e.g., model_name, use_text_extract, selector, dom_settle_timeout_ms). Returns: ExtractResult: Depending on the type of the schema provided, the result will be a Pydantic model or JSON representation of the extracted data. """ await self.ensure_injection() - # Otherwise use API implementation - # Allow for no options to extract the entire page - if options is None: - options_obj = ExtractOptions() + options_dict = {} + + if isinstance(options_or_instruction, ExtractOptions): + options_obj = options_or_instruction + else: + if isinstance(options_or_instruction, str): + options_dict["instruction"] = options_or_instruction + + # Merge keyword overrides (highest priority) + options_dict.update(kwargs) + + # Ensure schema_definition is only set once (explicit arg precedence) + if schema is not None: + options_dict["schema_definition"] = schema + + if options_dict: + try: + options_obj = ExtractOptions(**options_dict) + except Exception as e: + raise TypeError(f"Invalid extract options: {e}") from e + else: + # No options_dict provided and no ExtractOptions given: full page extract. + options_obj = None + + # If we started with an existing ExtractOptions instance and the caller + # explicitly provided a schema, override it + if ( + schema is not None + and isinstance(options_obj, ExtractOptions) + and options_obj.schema_definition != schema + ): + options_obj = options_obj.model_copy(update={"schema_definition": schema}) + + if options_obj is None: payload = {} - # Convert string to ExtractOptions if needed - elif isinstance(options, str): - options_obj = ExtractOptions(instruction=options) - payload = options_obj.model_dump(exclude_none=True, by_alias=True) - # Otherwise, it should be an ExtractOptions object else: - options_obj = options - # Allow extraction without instruction if other options (like schema) are provided payload = options_obj.model_dump(exclude_none=True, by_alias=True) # Determine the schema to pass to the handler schema_to_validate_with = None if ( - hasattr(options_obj, "schema_definition") + options_obj is not None + and options_obj.schema_definition is not None and options_obj.schema_definition != DEFAULT_EXTRACT_SCHEMA ): if isinstance(options_obj.schema_definition, type) and issubclass( @@ -277,7 +333,7 @@ async def extract( ) # Allow for no options to extract the entire page - if options is None: + if options_obj is None: # Call local extract implementation with no options result = await self._extract_handler.extract( None,