Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions stagehand/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class StagehandConfig(BaseModel):
browserbase_session_id (Optional[str]): Session ID for resuming Browserbase sessions.
model_name (Optional[str]): Name of the model to use.
model_api_key (Optional[str]): Model API key.
model_client_options (Optional[dict[str, Any]]): Options for the model client.
logger (Optional[Callable[[Any], None]]): Custom logging function.
verbose (Optional[int]): Verbosity level for logs (1=minimal, 2=medium, 3=detailed).
use_rich_logging (bool): Whether to use Rich for colorized logging.
Expand Down Expand Up @@ -50,6 +51,11 @@ class StagehandConfig(BaseModel):
model_api_key: Optional[str] = Field(
None, alias="modelApiKey", description="Model API key"
)
model_client_options: Optional[dict[str, Any]] = Field(
None,
alias="modelClientOptions",
description="Configuration options for the language model client (i.e. api_base)",
)
verbose: Optional[int] = Field(
1,
description="Verbosity level for logs: 0=minimal (ERROR), 1=medium (INFO), 2=detailed (DEBUG)",
Expand Down
2 changes: 1 addition & 1 deletion stagehand/llm/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
setattr(litellm, key, value)
self.logger.debug(f"Set global litellm.{key}", category="llm")
# Handle common aliases or expected config names if necessary
elif key == "api_base": # Example: map api_base if needed
elif key == "api_base" or key == "baseURL":
litellm.api_base = value
self.logger.debug(
f"Set global litellm.api_base to {value}", category="llm"
Expand Down
10 changes: 5 additions & 5 deletions stagehand/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,11 @@ def __init__(

# Handle non-config parameters
self.api_url = self.config.api_url

# Handle model-related settings
self.model_client_options = self.config.model_client_options or {}
self.model_api_key = self.config.model_api_key or os.getenv("MODEL_API_KEY")

self.model_name = self.config.model_name

# Extract frequently used values from config for convenience
Expand All @@ -181,11 +185,7 @@ def __init__(
self.local_browser_launch_options = (
self.config.local_browser_launch_options or {}
)

# Handle model-related settings
self.model_client_options = {}
if self.model_api_key and "apiKey" not in self.model_client_options:
self.model_client_options["apiKey"] = self.model_api_key
self.model_client_options["apiKey"] = self.model_api_key

# Handle browserbase session create params
self.browserbase_session_create_params = make_serializable(
Expand Down
1 change: 1 addition & 0 deletions tests/unit/llm/test_llm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_llm_client_with_custom_options(self):
api_key="test-key",
default_model="gpt-4o-mini",
stagehand_logger=StagehandLogger(),
api_base="https://test-api-base.com",
)

assert client.default_model == "gpt-4o-mini"
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_client_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def mock_client(self):
browserbase_session_id="test-session-123",
api_key="test-api-key",
project_id="test-project-id",
model_api_key="test-model-api-key",
model_client_options={"apiKey": "test-model-api-key"}
)
return client

Expand Down
31 changes: 30 additions & 1 deletion tests/unit/test_client_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_init_with_direct_params(self):
browserbase_session_id="test-session",
api_key="test-api-key",
project_id="test-project-id",
model_api_key="test-model-api-key",
model_client_options={"apiKey": "test-model-api-key"},
verbose=2,
)

Expand Down Expand Up @@ -228,3 +228,32 @@ async def mock_create_session():
# Call _create_session and expect error
with pytest.raises(RuntimeError, match="Invalid response format"):
await client._create_session()

@mock.patch.dict(os.environ, {"MODEL_API_KEY": "test-model-api-key"}, clear=True)
def test_init_with_model_api_key_in_env(self):
config = StagehandConfig(env="LOCAL")
client = Stagehand(config=config)
assert client.model_api_key == "test-model-api-key"

def test_init_with_custom_llm(self):
config = StagehandConfig(
env="LOCAL",
model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"}
)
client = Stagehand(config=config)
assert client.model_api_key == "custom-llm-key"
assert client.model_client_options["apiKey"] == "custom-llm-key"
assert client.model_client_options["baseURL"] == "https://custom-llm.com"

def test_init_with_custom_llm_override(self):
config = StagehandConfig(
env="LOCAL",
model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"}
)
client = Stagehand(
config=config,
model_client_options={"apiKey": "override-llm-key", "baseURL": "https://override-llm.com"}
)
assert client.model_api_key == "override-llm-key"
assert client.model_client_options["apiKey"] == "override-llm-key"
assert client.model_client_options["baseURL"] == "https://override-llm.com"
Loading