|
1 | 1 | import os
|
2 |
| -from typing import Any, Callable, Literal, Optional |
| 2 | +from typing import Any, Callable, Literal, Optional, Union |
3 | 3 |
|
4 | 4 | from browserbase.types import SessionCreateParams as BrowserbaseSessionCreateParams
|
5 |
| -from pydantic import BaseModel, ConfigDict, Field |
| 5 | +from pydantic import BaseModel, ConfigDict, Field, field_validator |
6 | 6 |
|
7 | 7 | from stagehand.schemas import AvailableModel
|
8 | 8 |
|
@@ -65,7 +65,7 @@ class StagehandConfig(BaseModel):
|
65 | 65 | alias="domSettleTimeoutMs",
|
66 | 66 | description="Timeout for DOM to settle (in ms)",
|
67 | 67 | )
|
68 |
| - browserbase_session_create_params: Optional[BrowserbaseSessionCreateParams] = Field( |
| 68 | + browserbase_session_create_params: Optional[Union[BrowserbaseSessionCreateParams, dict[str, Any]]] = Field( |
69 | 69 | None,
|
70 | 70 | alias="browserbaseSessionCreateParams",
|
71 | 71 | description="Browserbase session create params",
|
@@ -111,6 +111,17 @@ class StagehandConfig(BaseModel):
|
111 | 111 | )
|
112 | 112 |
|
113 | 113 | model_config = ConfigDict(populate_by_name=True)
|
| 114 | + |
| 115 | + @field_validator('browserbase_session_create_params', mode='before') |
| 116 | + @classmethod |
| 117 | + def validate_browserbase_params(cls, v, info): |
| 118 | + """Validate and convert browserbase session create params.""" |
| 119 | + if isinstance(v, dict) and 'project_id' not in v: |
| 120 | + values = info.data |
| 121 | + project_id = values.get('project_id') or values.get('projectId') |
| 122 | + if project_id: |
| 123 | + v = {**v, 'project_id': project_id} |
| 124 | + return v |
114 | 125 |
|
115 | 126 | def with_overrides(self, **overrides) -> "StagehandConfig":
|
116 | 127 | """
|
|
0 commit comments