Skip to content

Commit 7270b64

Browse files
add act from observe (#12)
* add act from observe
* change all camel case to snake case
1 parent c60cd13 commit 7270b64

File tree

4 files changed

+59
-33
lines changed

4 files changed

+59
-33
lines changed

stagehand/client.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@ def __init__(
4747
debug_dom: Optional[bool] = None,
4848
httpx_client: Optional[httpx.AsyncClient] = None,
4949
timeout_settings: Optional[httpx.Timeout] = None,
50+
model_client_options: Optional[Dict[str, Any]] = None,
5051
):
5152
"""
5253
Initialize the Stagehand client.
@@ -65,6 +66,7 @@ def __init__(
6566
debug_dom (Optional[bool]): Whether to enable DOM debugging mode.
6667
httpx_client (Optional[httpx.AsyncClient]): Optional custom httpx.AsyncClient instance.
6768
timeout_settings (Optional[httpx.Timeout]): Optional custom timeout settings for httpx.
69+
model_client_options (Optional[Dict[str, Any]]): Optional model client options.
6870
"""
6971
self.server_url = server_url or os.getenv("STAGEHAND_SERVER_URL")
7072

@@ -92,6 +94,7 @@ def __init__(
9294
# Additional config parameters available for future use:
9395
self.headless = config.headless
9496
self.enable_caching = config.enable_caching
97+
self.model_client_options = model_client_options
9598
else:
9699
self.browserbase_api_key = browserbase_api_key or os.getenv(
97100
"BROWSERBASE_API_KEY"
@@ -104,6 +107,7 @@ def __init__(
104107
self.model_name = model_name
105108
self.dom_settle_timeout_ms = dom_settle_timeout_ms
106109
self.debug_dom = debug_dom
110+
self.model_client_options = model_client_options
107111

108112
self.on_log = on_log
109113
self.verbose = verbose
@@ -312,6 +316,9 @@ async def _create_session(self):
312316
"verbose": self.verbose,
313317
"debugDom": self.debug_dom,
314318
}
319+
320+
if hasattr(self, "model_client_options") and self.model_client_options:
321+
payload["modelClientOptions"] = self.model_client_options
315322

316323
headers = {
317324
"x-bb-api-key": self.browserbase_api_key,
@@ -350,21 +357,25 @@ async def _execute(self, method: str, payload: Dict[str, Any]) -> Any:
350357
}
351358
if self.model_api_key:
352359
headers["x-model-api-key"] = self.model_api_key
353-
360+
361+
modified_payload = dict(payload)
362+
if hasattr(self, "model_client_options") and self.model_client_options and "modelClientOptions" not in modified_payload:
363+
modified_payload["modelClientOptions"] = self.model_client_options
364+
354365
client = self.httpx_client or httpx.AsyncClient(timeout=self.timeout_settings)
355366
self._log(f"\n==== EXECUTING {method.upper()} ====", level=3)
356367
self._log(
357368
f"URL: {self.server_url}/sessions/{self.session_id}/{method}", level=3
358369
)
359-
self._log(f"Payload: {payload}", level=3)
370+
self._log(f"Payload: {modified_payload}", level=3)
360371
self._log(f"Headers: {headers}", level=3)
361-
372+
362373
async with client:
363374
try:
364375
async with client.stream(
365376
"POST",
366377
f"{self.server_url}/sessions/{self.session_id}/{method}",
367-
json=payload,
378+
json=modified_payload,
368379
headers=headers,
369380
) as response:
370381
if response.status_code != 200:

stagehand/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class StagehandConfig(BaseModel):
2020
enable_caching (Optional[bool]): Enable caching functionality.
2121
browserbase_session_id (Optional[str]): Session ID for resuming Browserbase sessions.
2222
model_name (Optional[str]): Name of the model to use.
23-
selfHeal (Optional[bool]): Enable self-healing functionality.
23+
self_heal (Optional[bool]): Enable self-healing functionality.
2424
"""
2525

2626
env: str = "BROWSERBASE"
@@ -53,8 +53,8 @@ class StagehandConfig(BaseModel):
5353
model_name: Optional[str] = Field(
5454
AvailableModel.GPT_4O, alias="modelName", description="Name of the model to use"
5555
)
56-
selfHeal: Optional[bool] = Field(
57-
True, description="Enable self-healing functionality"
56+
self_heal: Optional[bool] = Field(
57+
True, alias="selfHeal", description="Enable self-healing functionality"
5858
)
5959

6060
class Config:

stagehand/page.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ async def goto(
5252
if timeout is not None:
5353
options["timeout"] = timeout
5454
if wait_until is not None:
55+
options["wait_until"] = wait_until
5556
options["waitUntil"] = wait_until
5657

5758
payload = {"url": url}
@@ -68,18 +69,31 @@ async def act(self, options: Union[str, ActOptions, ObserveResult]) -> ActResult
6869
Execute an AI action via the Stagehand server.
6970
7071
Args:
71-
options (Union[str, ActOptions]): Either a string with the action command or
72-
a Pydantic model encapsulating the action.
73-
See `stagehand.schemas.ActOptions` for details on expected fields.
72+
options (Union[str, ActOptions, ObserveResult]):
73+
- A string with the action command to be executed by the AI
74+
- An ActOptions object encapsulating the action command and optional parameters
75+
- An ObserveResult with selector and method fields for direct execution without LLM
76+
77+
When an ObserveResult with both 'selector' and 'method' fields is provided,
78+
the SDK will directly execute the action against the selector using the method
79+
and arguments provided, bypassing the LLM processing.
7480
7581
Returns:
76-
Any: The result from the Stagehand server's action execution.
82+
ActResult: The result from the Stagehand server's action execution.
7783
"""
84+
# Check if options is an ObserveResult with both selector and method
85+
if isinstance(options, ObserveResult) and hasattr(options, "selector") and hasattr(options, "method"):
86+
# For ObserveResult, we directly pass it to the server which will
87+
# execute the method against the selector
88+
payload = options.model_dump(exclude_none=True)
7889
# Convert string to ActOptions if needed
79-
if isinstance(options, str):
90+
elif isinstance(options, str):
8091
options = ActOptions(action=options)
92+
payload = options.model_dump(exclude_none=True)
93+
# Otherwise, it should be an ActOptions object
94+
else:
95+
payload = options.model_dump(exclude_none=True)
8196

82-
payload = options.model_dump(exclude_none=True)
8397
lock = self._stagehand._get_lock_for_session()
8498
async with lock:
8599
result = await self._stagehand._execute("act", payload)

stagehand/schemas.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ class ActOptions(BaseModel):
2525
Attributes:
2626
action (str): The action command to be executed by the AI.
2727
variables: Optional[Dict[str, str]] = None
28-
modelName: Optional[AvailableModel] = None
29-
slowDomBasedAct: Optional[bool] = None
28+
model_name: Optional[AvailableModel] = None
29+
slow_dom_based_act: Optional[bool] = None
3030
"""
3131

3232
action: str = Field(..., description="The action command to be executed by the AI.")
3333
variables: Optional[Dict[str, str]] = None
34-
modelName: Optional[AvailableModel] = None
35-
slowDomBasedAct: Optional[bool] = None
34+
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
35+
slow_dom_based_act: Optional[bool] = Field(None, alias="slowDomBasedAct")
3636

3737

3838
class ActResult(BaseModel):
@@ -56,25 +56,26 @@ class ExtractOptions(BaseModel):
5656
5757
Attributes:
5858
instruction (str): Instruction specifying what data to extract using AI.
59-
modelName: Optional[AvailableModel] = None
59+
model_name: Optional[AvailableModel] = None
6060
selector: Optional[str] = None
61-
schemaDefinition (Union[Dict[str, Any], Type[BaseModel]]): A JSON schema or Pydantic model that defines the structure of the expected data.
61+
schema_definition (Union[Dict[str, Any], Type[BaseModel]]): A JSON schema or Pydantic model that defines the structure of the expected data.
6262
Note: If passing a Pydantic model, invoke its .model_json_schema() method to ensure the schema is JSON serializable.
63-
useTextExtract: Optional[bool] = None
63+
use_text_extract: Optional[bool] = None
6464
"""
6565

6666
instruction: str = Field(
6767
..., description="Instruction specifying what data to extract using AI."
6868
)
69-
modelName: Optional[AvailableModel] = None
69+
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
7070
selector: Optional[str] = None
71-
# IMPORTANT: If using a Pydantic model for schemaDefinition, please call its .model_json_schema() method
71+
# IMPORTANT: If using a Pydantic model for schema_definition, please call its .model_json_schema() method
7272
# to convert it to a JSON serializable dictionary before sending it with the extract command.
73-
schemaDefinition: Union[Dict[str, Any], Type[BaseModel]] = Field(
73+
schema_definition: Union[Dict[str, Any], Type[BaseModel]] = Field(
7474
default=DEFAULT_EXTRACT_SCHEMA,
7575
description="A JSON schema or Pydantic model that defines the structure of the expected data.",
76+
alias="schemaDefinition",
7677
)
77-
useTextExtract: Optional[bool] = True
78+
use_text_extract: Optional[bool] = Field(True, alias="useTextExtract")
7879

7980
class Config:
8081
arbitrary_types_allowed = True
@@ -108,19 +109,19 @@ class ObserveOptions(BaseModel):
108109
109110
Attributes:
110111
instruction (str): Instruction detailing what the AI should observe.
111-
modelName: Optional[AvailableModel] = None
112-
onlyVisible: Optional[bool] = None
113-
returnAction: Optional[bool] = None
114-
drawOverlay: Optional[bool] = None
112+
model_name: Optional[AvailableModel] = None
113+
only_visible: Optional[bool] = None
114+
return_action: Optional[bool] = None
115+
draw_overlay: Optional[bool] = None
115116
"""
116117

117118
instruction: str = Field(
118119
..., description="Instruction detailing what the AI should observe."
119120
)
120-
onlyVisible: Optional[bool] = False
121-
modelName: Optional[AvailableModel] = None
122-
returnAction: Optional[bool] = None
123-
drawOverlay: Optional[bool] = None
121+
only_visible: Optional[bool] = Field(False, alias="onlyVisible")
122+
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
123+
return_action: Optional[bool] = Field(None, alias="returnAction")
124+
draw_overlay: Optional[bool] = Field(None, alias="drawOverlay")
124125

125126

126127
class ObserveResult(BaseModel):
@@ -132,7 +133,7 @@ class ObserveResult(BaseModel):
132133
description: str = Field(
133134
..., description="The description of the observed element."
134135
)
135-
backendNodeId: Optional[int] = None
136+
backend_node_id: Optional[int] = Field(None, alias="backendNodeId")
136137
method: Optional[str] = None
137138
arguments: Optional[List[str]] = None
138139

0 commit comments

Comments
 (0)