Skip to content

Commit 436668d

Browse files
committed
fixed snake_case migration
1 parent 7270b64 commit 436668d

File tree

6 files changed

+99
-38
lines changed

6 files changed

+99
-38
lines changed

MANIFEST.in

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
include README.md
2+
include requirements.txt
3+
global-exclude *.pyc
4+
global-exclude __pycache__
5+
global-exclude .DS_Store
6+
global-exclude */node_modules/*

examples/example.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ async def main():
7979

8080
console.print("\n▶️ [highlight] Performing action:[/] search for openai")
8181
await page.act("search for openai")
82+
await page.keyboard.press("Enter")
8283
console.print("✅ [success]Performing Action:[/] Action completed successfully")
8384

8485
console.print("\n▶️ [highlight] Observing page[/] for news button")

stagehand/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from .config import StagehandConfig
1414
from .page import StagehandPage
15-
from .utils import default_log_handler
15+
from .utils import default_log_handler, convert_dict_keys_to_camel_case
1616

1717
load_dotenv()
1818

@@ -362,6 +362,9 @@ async def _execute(self, method: str, payload: Dict[str, Any]) -> Any:
362362
if hasattr(self, "model_client_options") and self.model_client_options and "modelClientOptions" not in modified_payload:
363363
modified_payload["modelClientOptions"] = self.model_client_options
364364

365+
# Convert snake_case keys to camelCase for the API
366+
modified_payload = convert_dict_keys_to_camel_case(modified_payload)
367+
365368
client = self.httpx_client or httpx.AsyncClient(timeout=self.timeout_settings)
366369
self._log(f"\n==== EXECUTING {method.upper()} ====", level=3)
367370
self._log(

stagehand/page.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ async def act(self, options: Union[str, ActOptions, ObserveResult]) -> ActResult
8585
if isinstance(options, ObserveResult) and hasattr(options, "selector") and hasattr(options, "method"):
8686
# For ObserveResult, we directly pass it to the server which will
8787
# execute the method against the selector
88-
payload = options.model_dump(exclude_none=True)
88+
payload = options.model_dump(exclude_none=True, by_alias=True)
8989
# Convert string to ActOptions if needed
9090
elif isinstance(options, str):
9191
options = ActOptions(action=options)
92-
payload = options.model_dump(exclude_none=True)
92+
payload = options.model_dump(exclude_none=True, by_alias=True)
9393
# Otherwise, it should be an ActOptions object
9494
else:
95-
payload = options.model_dump(exclude_none=True)
95+
payload = options.model_dump(exclude_none=True, by_alias=True)
9696

9797
lock = self._stagehand._get_lock_for_session()
9898
async with lock:
@@ -117,7 +117,7 @@ async def observe(self, options: Union[str, ObserveOptions]) -> List[ObserveResu
117117
if isinstance(options, str):
118118
options = ObserveOptions(instruction=options)
119119

120-
payload = options.model_dump(exclude_none=True)
120+
payload = options.model_dump(exclude_none=True, by_alias=True)
121121
lock = self._stagehand._get_lock_for_session()
122122
async with lock:
123123
result = await self._stagehand._execute("observe", payload)
@@ -148,7 +148,7 @@ async def extract(self, options: Union[str, ExtractOptions]) -> ExtractResult:
148148
if isinstance(options, str):
149149
options = ExtractOptions(instruction=options)
150150

151-
payload = options.model_dump(exclude_none=True)
151+
payload = options.model_dump(exclude_none=True, by_alias=True)
152152
lock = self._stagehand._get_lock_for_session()
153153
async with lock:
154154
result = await self._stagehand._execute("extract", payload)

stagehand/schemas.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
22
from typing import Any, Dict, List, Optional, Type, Union
33

4-
from pydantic import BaseModel, Field
4+
from pydantic import BaseModel, Field, field_serializer
55

66
# Default extraction schema that matches the TypeScript version
77
DEFAULT_EXTRACT_SCHEMA = {
@@ -18,7 +18,18 @@ class AvailableModel(str, Enum):
1818
CLAUDE_3_7_SONNET_LATEST = "claude-3-7-sonnet-latest"
1919

2020

21-
class ActOptions(BaseModel):
21+
class StagehandBaseModel(BaseModel):
22+
"""Base model for all Stagehand models with camelCase conversion support"""
23+
24+
class Config:
25+
populate_by_name = True # Allow accessing fields by their Python name
26+
alias_generator = lambda field_name: ''.join(
27+
[field_name.split('_')[0]] +
28+
[word.capitalize() for word in field_name.split('_')[1:]]
29+
) # snake_case to camelCase
30+
31+
32+
class ActOptions(StagehandBaseModel):
2233
"""
2334
Options for the 'act' command.
2435
@@ -31,11 +42,11 @@ class ActOptions(BaseModel):
3142

3243
action: str = Field(..., description="The action command to be executed by the AI.")
3344
variables: Optional[Dict[str, str]] = None
34-
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
35-
slow_dom_based_act: Optional[bool] = Field(None, alias="slowDomBasedAct")
45+
model_name: Optional[AvailableModel] = None
46+
slow_dom_based_act: Optional[bool] = None
3647

3748

38-
class ActResult(BaseModel):
49+
class ActResult(StagehandBaseModel):
3950
"""
4051
Result of the 'act' command.
4152
@@ -50,7 +61,7 @@ class ActResult(BaseModel):
5061
action: str = Field(..., description="The action command that was executed.")
5162

5263

53-
class ExtractOptions(BaseModel):
64+
class ExtractOptions(StagehandBaseModel):
5465
"""
5566
Options for the 'extract' command.
5667
@@ -66,22 +77,28 @@ class ExtractOptions(BaseModel):
6677
instruction: str = Field(
6778
..., description="Instruction specifying what data to extract using AI."
6879
)
69-
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
80+
model_name: Optional[AvailableModel] = None
7081
selector: Optional[str] = None
7182
# IMPORTANT: If using a Pydantic model for schema_definition, please call its .model_json_schema() method
7283
# to convert it to a JSON serializable dictionary before sending it with the extract command.
7384
schema_definition: Union[Dict[str, Any], Type[BaseModel]] = Field(
7485
default=DEFAULT_EXTRACT_SCHEMA,
7586
description="A JSON schema or Pydantic model that defines the structure of the expected data.",
76-
alias="schemaDefinition",
7787
)
78-
use_text_extract: Optional[bool] = Field(True, alias="useTextExtract")
88+
use_text_extract: Optional[bool] = True
89+
90+
@field_serializer('schema_definition')
91+
def serialize_schema_definition(self, schema_definition: Union[Dict[str, Any], Type[BaseModel]]) -> Dict[str, Any]:
92+
"""Serialize schema_definition to a JSON schema if it's a Pydantic model"""
93+
if isinstance(schema_definition, type) and issubclass(schema_definition, BaseModel):
94+
return schema_definition.model_json_schema()
95+
return schema_definition
7996

8097
class Config:
8198
arbitrary_types_allowed = True
8299

83100

84-
class ExtractResult(BaseModel):
101+
class ExtractResult(StagehandBaseModel):
85102
"""
86103
Result of the 'extract' command.
87104
@@ -103,7 +120,7 @@ def __getitem__(self, key):
103120
return getattr(self, key)
104121

105122

106-
class ObserveOptions(BaseModel):
123+
class ObserveOptions(StagehandBaseModel):
107124
"""
108125
Options for the 'observe' command.
109126
@@ -118,13 +135,13 @@ class ObserveOptions(BaseModel):
118135
instruction: str = Field(
119136
..., description="Instruction detailing what the AI should observe."
120137
)
121-
only_visible: Optional[bool] = Field(False, alias="onlyVisible")
122-
model_name: Optional[AvailableModel] = Field(None, alias="modelName")
123-
return_action: Optional[bool] = Field(None, alias="returnAction")
124-
draw_overlay: Optional[bool] = Field(None, alias="drawOverlay")
138+
only_visible: Optional[bool] = False
139+
model_name: Optional[AvailableModel] = None
140+
return_action: Optional[bool] = None
141+
draw_overlay: Optional[bool] = None
125142

126143

127-
class ObserveResult(BaseModel):
144+
class ObserveResult(StagehandBaseModel):
128145
"""
129146
Result of the 'observe' command.
130147
"""
@@ -133,7 +150,7 @@ class ObserveResult(BaseModel):
133150
description: str = Field(
134151
..., description="The description of the observed element."
135152
)
136-
backend_node_id: Optional[int] = Field(None, alias="backendNodeId")
153+
backend_node_id: Optional[int] = None
137154
method: Optional[str] = None
138155
arguments: Optional[List[str]] = None
139156

stagehand/utils.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,57 @@
1+
import asyncio
12
import logging
3+
from typing import Any, Dict
24

5+
# Setup logging
36
logger = logging.getLogger(__name__)
7+
handler = logging.StreamHandler()
8+
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
9+
logger.addHandler(handler)
410

11+
async def default_log_handler(log_data: Dict[str, Any]) -> None:
12+
"""Default handler for log messages from the Stagehand server."""
13+
level = log_data.get("level", "info").lower()
14+
message = log_data.get("message", "")
515

6-
async def default_log_handler(log_data: dict):
16+
log_method = getattr(logger, level, logger.info)
17+
log_method(message)
18+
19+
20+
def snake_to_camel(snake_str: str) -> str:
721
"""
8-
Default async log handler that shows detailed server logs.
9-
Can be overridden by passing a custom handler to Stagehand's constructor.
22+
Convert a snake_case string to camelCase.
23+
24+
Args:
25+
snake_str: The snake_case string to convert
26+
27+
Returns:
28+
The converted camelCase string
1029
"""
11-
if "type" in log_data:
12-
log_type = log_data["type"]
13-
data = log_data.get("data", {})
30+
components = snake_str.split('_')
31+
return components[0] + ''.join(x.title() for x in components[1:])
32+
1433

15-
if log_type == "system":
16-
logger.info(f"🔧 SYSTEM: {data}")
17-
elif log_type == "log":
18-
logger.info(f"📝 LOG: {data}")
19-
else:
20-
logger.info(f"ℹ️ OTHER [{log_type}]: {data}")
21-
else:
22-
# Fallback for any other format
23-
logger.info(f"🤖 RAW LOG: {log_data}")
34+
def convert_dict_keys_to_camel_case(data: Dict[str, Any]) -> Dict[str, Any]:
35+
"""
36+
Convert all keys in a dictionary from snake_case to camelCase.
37+
Works recursively for nested dictionaries.
38+
39+
Args:
40+
data: Dictionary with snake_case keys
41+
42+
Returns:
43+
Dictionary with camelCase keys
44+
"""
45+
result = {}
46+
47+
for key, value in data.items():
48+
if isinstance(value, dict):
49+
value = convert_dict_keys_to_camel_case(value)
50+
elif isinstance(value, list):
51+
value = [convert_dict_keys_to_camel_case(item) if isinstance(item, dict) else item for item in value]
52+
53+
# Convert snake_case key to camelCase
54+
camel_key = snake_to_camel(key)
55+
result[camel_key] = value
56+
57+
return result

0 commit comments

Comments
 (0)