Skip to content

Commit 5f29b21

Browse files
Adding initial support for LLM Customisation (#132)
* Adding support for LLM Customisation * Added back model_api_key config * Fixed README * Update README.md Co-authored-by: Miguel <[email protected]> * Update stagehand/llm/client.py --------- Co-authored-by: Miguel <[email protected]>
1 parent 6586996 commit 5f29b21

File tree

6 files changed

+42
-9
lines changed

6 files changed

+42
-9
lines changed

stagehand/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class StagehandConfig(BaseModel):
2020
browserbase_session_id (Optional[str]): Session ID for resuming Browserbase sessions.
2121
model_name (Optional[str]): Name of the model to use.
2222
model_api_key (Optional[str]): Model API key.
23+
model_client_options (Optional[dict[str, Any]]): Options for the model client.
2324
logger (Optional[Callable[[Any], None]]): Custom logging function.
2425
verbose (Optional[int]): Verbosity level for logs (1=minimal, 2=medium, 3=detailed).
2526
use_rich_logging (bool): Whether to use Rich for colorized logging.
@@ -50,6 +51,9 @@ class StagehandConfig(BaseModel):
5051
model_api_key: Optional[str] = Field(
5152
None, alias="modelApiKey", description="Model API key"
5253
)
54+
model_client_options: Optional[dict[str, Any]] = Field(
55+
None, alias="modelClientOptions", description="Configuration options for the language model client (i.e. apiKey, baseURL)",
56+
)
5357
verbose: Optional[int] = Field(
5458
1,
5559
description="Verbosity level for logs: 0=minimal (ERROR), 1=medium (INFO), 2=detailed (DEBUG)",

stagehand/llm/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
setattr(litellm, key, value)
5555
self.logger.debug(f"Set global litellm.{key}", category="llm")
5656
# Handle common aliases or expected config names if necessary
57-
elif key == "api_base": # Example: map api_base if needed
57+
elif key == "api_base" or key == "baseURL":
5858
litellm.api_base = value
5959
self.logger.debug(
6060
f"Set global litellm.api_base to {value}", category="llm"

stagehand/main.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ def __init__(
6868

6969
# Handle non-config parameters
7070
self.api_url = self.config.api_url
71-
self.model_api_key = self.config.model_api_key or os.getenv("MODEL_API_KEY")
71+
72+
# Handle model-related settings
73+
self.model_client_options = self.config.model_client_options or {}
74+
self.model_api_key = self.config.model_api_key or self.model_client_options.get("apiKey") or os.getenv("MODEL_API_KEY")
75+
7276
self.model_name = self.config.model_name
7377

7478
# Extract frequently used values from config for convenience
@@ -89,11 +93,6 @@ def __init__(
8993
self.config.local_browser_launch_options or {}
9094
)
9195

92-
# Handle model-related settings
93-
self.model_client_options = {}
94-
if self.model_api_key and "apiKey" not in self.model_client_options:
95-
self.model_client_options["apiKey"] = self.model_api_key
96-
9796
# Handle browserbase session create params
9897
self.browserbase_session_create_params = make_serializable(
9998
self.config.browserbase_session_create_params

tests/unit/llm/test_llm_integration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def test_llm_client_with_custom_options(self):
4040
api_key="test-key",
4141
default_model="gpt-4o-mini",
4242
stagehand_logger=StagehandLogger(),
43+
api_base="https://test-api-base.com",
4344
)
4445

4546
assert client.default_model == "gpt-4o-mini"

tests/unit/test_client_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ async def mock_client(self):
1919
browserbase_session_id="test-session-123",
2020
api_key="test-api-key",
2121
project_id="test-project-id",
22-
model_api_key="test-model-api-key",
22+
model_client_options={"apiKey": "test-model-api-key"}
2323
)
2424
return client
2525

tests/unit/test_client_initialization.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_init_with_direct_params(self):
2323
browserbase_session_id="test-session",
2424
api_key="test-api-key",
2525
project_id="test-project-id",
26-
model_api_key="test-model-api-key",
26+
model_client_options={"apiKey": "test-model-api-key"},
2727
verbose=2,
2828
)
2929

@@ -203,3 +203,32 @@ async def mock_create_session():
203203
# Call _create_session and expect error
204204
with pytest.raises(RuntimeError, match="Invalid response format"):
205205
await client._create_session()
206+
207+
@mock.patch.dict(os.environ, {"MODEL_API_KEY": "test-model-api-key"}, clear=True)
208+
def test_init_with_model_api_key_in_env(self):
209+
config = StagehandConfig(env="LOCAL")
210+
client = Stagehand(config=config)
211+
assert client.model_api_key == "test-model-api-key"
212+
213+
def test_init_with_custom_llm(self):
214+
config = StagehandConfig(
215+
env="LOCAL",
216+
model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"}
217+
)
218+
client = Stagehand(config=config)
219+
assert client.model_api_key == "custom-llm-key"
220+
assert client.model_client_options["apiKey"] == "custom-llm-key"
221+
assert client.model_client_options["baseURL"] == "https://custom-llm.com"
222+
223+
def test_init_with_custom_llm_override(self):
224+
config = StagehandConfig(
225+
env="LOCAL",
226+
model_client_options={"apiKey": "custom-llm-key", "baseURL": "https://custom-llm.com"}
227+
)
228+
client = Stagehand(
229+
config=config,
230+
model_client_options={"apiKey": "override-llm-key", "baseURL": "https://override-llm.com"}
231+
)
232+
assert client.model_api_key == "override-llm-key"
233+
assert client.model_client_options["apiKey"] == "override-llm-key"
234+
assert client.model_client_options["baseURL"] == "https://override-llm.com"

0 commit comments

Comments
 (0)