|
1 | 1 | import os
|
2 |
| -from typing import Any, Callable, Literal, Optional |
| 2 | +from typing import Any, Callable, Literal, Optional, Union |
3 | 3 |
|
4 | 4 | from browserbase.types import SessionCreateParams as BrowserbaseSessionCreateParams
|
5 |
| -from pydantic import BaseModel, ConfigDict, Field |
| 5 | +from pydantic import BaseModel, ConfigDict, Field, field_validator |
6 | 6 |
|
7 | 7 | from stagehand.schemas import AvailableModel
|
8 | 8 |
|
@@ -71,7 +71,7 @@ class StagehandConfig(BaseModel):
|
71 | 71 | alias="domSettleTimeoutMs",
|
72 | 72 | description="Timeout for DOM to settle (in ms)",
|
73 | 73 | )
|
74 |
| - browserbase_session_create_params: Optional[BrowserbaseSessionCreateParams] = Field( |
| 74 | + browserbase_session_create_params: Optional[Union[BrowserbaseSessionCreateParams, dict[str, Any]]] = Field( |
75 | 75 | None,
|
76 | 76 | alias="browserbaseSessionCreateParams",
|
77 | 77 | description="Browserbase session create params",
|
@@ -117,6 +117,17 @@ class StagehandConfig(BaseModel):
|
117 | 117 | )
|
118 | 118 |
|
119 | 119 | model_config = ConfigDict(populate_by_name=True)
|
| 120 | + |
| 121 | + @field_validator('browserbase_session_create_params', mode='before') |
| 122 | + @classmethod |
| 123 | + def validate_browserbase_params(cls, v, info): |
| 124 | + """Validate and convert browserbase session create params.""" |
| 125 | + if isinstance(v, dict) and 'project_id' not in v: |
| 126 | + values = info.data |
| 127 | + project_id = values.get('project_id') or values.get('projectId') |
| 128 | + if project_id: |
| 129 | + v = {**v, 'project_id': project_id} |
| 130 | + return v |
120 | 131 |
|
121 | 132 | def with_overrides(self, **overrides) -> "StagehandConfig":
|
122 | 133 | """
|
|
0 commit comments