Skip to content

Commit 20b977c

Browse files
update kwargs (#93)
* update kwargs
* move kwargs towards user interaction from handlers
* revert
* add kwargs to act/extract/observe
* cleanup

---------

Co-authored-by: miguel <[email protected]>
1 parent 7068412 commit 20b977c

File tree

1 file changed

+87
-31
lines changed

1 file changed

+87
-31
lines changed

stagehand/page.py

Lines changed: 87 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -134,8 +134,12 @@ async def act(
134134
elif isinstance(action_or_result, str):
135135
options = ActOptions(action=action_or_result, **kwargs)
136136
payload = options.model_dump(exclude_none=True, by_alias=True)
137+
elif isinstance(action_or_result, ActOptions):
138+
payload = action_or_result.model_dump(exclude_none=True, by_alias=True)
137139
else:
138-
payload = options.model_dump(exclude_none=True, by_alias=True)
140+
raise TypeError(
141+
"Invalid arguments for 'act'. Expected str, ObserveResult, or ActOptions."
142+
)
139143

140144
# TODO: Temporary until we move api based logic to client
141145
if self._stagehand.env == "LOCAL":
@@ -158,12 +162,19 @@ async def act(
158162
return ActResult(**result)
159163
return result
160164

161-
async def observe(self, options: Union[str, ObserveOptions]) -> list[ObserveResult]:
165+
async def observe(
166+
self,
167+
options_or_instruction: Union[str, ObserveOptions, None] = None,
168+
**kwargs,
169+
) -> list[ObserveResult]:
162170
"""
163171
Make an AI observation via the Stagehand server.
164172
165173
Args:
166-
instruction (str): The observation instruction for the AI.
174+
options_or_instruction (Union[str, ObserveOptions, None]):
175+
- A string with the observation instruction for the AI.
176+
- An ObserveOptions object.
177+
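- None when the options are supplied entirely through **kwargs; a TypeError is raised if neither an instruction nor any kwargs are provided.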
167178
**kwargs: Additional options corresponding to fields in ObserveOptions
168179
(e.g., model_name, only_visible, return_action).
169180
@@ -172,15 +183,29 @@ async def observe(self, options: Union[str, ObserveOptions]) -> list[ObserveResu
172183
"""
173184
await self.ensure_injection()
174185

175-
# Convert string to ObserveOptions if needed
176-
if isinstance(options, str):
177-
options = ObserveOptions(instruction=options)
178-
# Handle None by creating an empty options object
179-
elif options is None:
180-
options = ObserveOptions()
186+
options_dict = {}
187+
188+
if isinstance(options_or_instruction, ObserveOptions):
189+
# Already a pydantic object – take it as is.
190+
options_obj = options_or_instruction
191+
else:
192+
if isinstance(options_or_instruction, str):
193+
options_dict["instruction"] = options_or_instruction
194+
195+
# Merge any explicit keyword arguments (highest priority)
196+
options_dict.update(kwargs)
197+
198+
if not options_dict:
199+
raise TypeError("No instruction provided for observe.")
200+
201+
try:
202+
options_obj = ObserveOptions(**options_dict)
203+
except Exception as e:
204+
raise TypeError(f"Invalid observe options: {e}") from e
205+
206+
# Serialized payload for server / local handlers
207+
payload = options_obj.model_dump(exclude_none=True, by_alias=True)
181208

182-
# Otherwise use API implementation
183-
payload = options.model_dump(exclude_none=True, by_alias=True)
184209
# If in LOCAL mode, use local implementation
185210
if self._stagehand.env == "LOCAL":
186211
self._stagehand.logger.debug(
@@ -193,7 +218,7 @@ async def observe(self, options: Union[str, ObserveOptions]) -> list[ObserveResu
193218

194219
# Call local observe implementation
195220
result = await self._observe_handler.observe(
196-
options,
221+
options_obj,
197222
from_act=False,
198223
)
199224

@@ -216,43 +241,74 @@ async def observe(self, options: Union[str, ObserveOptions]) -> list[ObserveResu
216241
return []
217242

218243
async def extract(
219-
self, options: Union[str, ExtractOptions, None] = None
244+
self,
245+
options_or_instruction: Union[str, ExtractOptions, None] = None,
246+
*,
247+
schema: Optional[type[BaseModel]] = None,
248+
**kwargs,
220249
) -> ExtractResult:
221-
# TODO update args
222250
"""
223251
Extract data using AI via the Stagehand server.
224252
225253
Args:
226-
instruction (Optional[str]): Instruction specifying what data to extract.
227-
If None, attempts to extract the entire page content
228-
based on other kwargs (e.g., schema_definition).
254+
options_or_instruction (Union[str, ExtractOptions, None]):
255+
- A string with the instruction specifying what data to extract.
256+
- An ExtractOptions object.
257+
- None to extract the entire page content.
258+
schema (Optional[type[BaseModel]]):
259+
A Pydantic model class that defines the structure
260+
of the expected extracted data.
229261
**kwargs: Additional options corresponding to fields in ExtractOptions
230-
(e.g., schema_definition, model_name, use_text_extract).
262+
(e.g., model_name, use_text_extract, selector, dom_settle_timeout_ms).
231263
232264
Returns:
233265
ExtractResult: Depending on the type of the schema provided, the result will be a Pydantic model or JSON representation of the extracted data.
234266
"""
235267
await self.ensure_injection()
236268

237-
# Otherwise use API implementation
238-
# Allow for no options to extract the entire page
239-
if options is None:
240-
options_obj = ExtractOptions()
269+
options_dict = {}
270+
271+
if isinstance(options_or_instruction, ExtractOptions):
272+
options_obj = options_or_instruction
273+
else:
274+
if isinstance(options_or_instruction, str):
275+
options_dict["instruction"] = options_or_instruction
276+
277+
# Merge keyword overrides (highest priority)
278+
options_dict.update(kwargs)
279+
280+
# Ensure schema_definition is only set once (explicit arg precedence)
281+
if schema is not None:
282+
options_dict["schema_definition"] = schema
283+
284+
if options_dict:
285+
try:
286+
options_obj = ExtractOptions(**options_dict)
287+
except Exception as e:
288+
raise TypeError(f"Invalid extract options: {e}") from e
289+
else:
290+
# No options_dict provided and no ExtractOptions given: full page extract.
291+
options_obj = None
292+
293+
# If we started with an existing ExtractOptions instance and the caller
294+
# explicitly provided a schema, override it
295+
if (
296+
schema is not None
297+
and isinstance(options_obj, ExtractOptions)
298+
and options_obj.schema_definition != schema
299+
):
300+
options_obj = options_obj.model_copy(update={"schema_definition": schema})
301+
302+
if options_obj is None:
241303
payload = {}
242-
# Convert string to ExtractOptions if needed
243-
elif isinstance(options, str):
244-
options_obj = ExtractOptions(instruction=options)
245-
payload = options_obj.model_dump(exclude_none=True, by_alias=True)
246-
# Otherwise, it should be an ExtractOptions object
247304
else:
248-
options_obj = options
249-
# Allow extraction without instruction if other options (like schema) are provided
250305
payload = options_obj.model_dump(exclude_none=True, by_alias=True)
251306

252307
# Determine the schema to pass to the handler
253308
schema_to_validate_with = None
254309
if (
255-
hasattr(options_obj, "schema_definition")
310+
options_obj is not None
311+
and options_obj.schema_definition is not None
256312
and options_obj.schema_definition != DEFAULT_EXTRACT_SCHEMA
257313
):
258314
if isinstance(options_obj.schema_definition, type) and issubclass(
@@ -277,7 +333,7 @@ async def extract(
277333
)
278334

279335
# Allow for no options to extract the entire page
280-
if options is None:
336+
if options_obj is None:
281337
# Call local extract implementation with no options
282338
result = await self._extract_handler.extract(
283339
None,

0 commit comments

Comments
 (0)