Skip to content

Commit 4de1858

Browse files
committed
Fix stagehand.metrics
1 parent 2e3eb1a commit 4de1858

File tree

4 files changed

+171
-22
lines changed

4 files changed

+171
-22
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"stagehand": patch
3+
---
4+
5+
Fix stagehand.metrics on env:BROWSERBASE

pyproject.toml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,19 @@ description = "Python SDK for Stagehand"
99
readme = "README.md"
1010
classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent",]
1111
requires-python = ">=3.9"
12-
dependencies = [ "httpx>=0.24.0", "python-dotenv>=1.0.0", "pydantic>=1.10.0", "playwright>=1.42.1", "requests>=2.31.0", "browserbase>=1.4.0", "rich>=13.7.0", "openai>=1.83.0", "anthropic>=0.51.0", "litellm>=1.72.0",]
12+
dependencies = [
13+
"httpx>=0.24.0",
14+
"python-dotenv>=1.0.0",
15+
"pydantic>=1.10.0",
16+
"playwright>=1.42.1",
17+
"requests>=2.31.0",
18+
"browserbase>=1.4.0",
19+
"rich>=13.7.0",
20+
"openai>=1.83.0",
21+
"anthropic>=0.51.0",
22+
"litellm>=1.72.0",
23+
"nest-asyncio>=1.6.0",
24+
]
1325
[[project.authors]]
1426
name = "Browserbase, Inc."
1527

stagehand/api.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from .utils import convert_dict_keys_to_camel_case
77

8-
__all__ = ["_create_session", "_execute"]
8+
__all__ = ["_create_session", "_execute", "_get_replay_metrics"]
99

1010

1111
async def _create_session(self):
@@ -177,3 +177,91 @@ async def _execute(self, method: str, payload: dict[str, Any]) -> Any:
177177
except Exception as e:
178178
self.logger.error(f"[EXCEPTION] {str(e)}")
179179
raise
180+
181+
182+
async def _get_replay_metrics(self):
    """
    Fetch replay metrics from the API endpoint /sessions/:id/replay and parse them
    into StagehandMetrics format.

    Every action in every page of the replay payload that carries a
    ``tokenUsage`` entry is added to the totals; actions whose method is one of
    ``act``, ``extract``, ``observe`` or ``agent`` are also added to the
    matching per-method counters.

    Returns:
        StagehandMetrics: the accumulated prompt tokens, completion tokens and
        inference time.

    Raises:
        ValueError: if ``self.session_id`` is not set.
        RuntimeError: if the endpoint answers with a non-200 status or the
            payload does not report ``success``.
    """
    from .metrics import StagehandMetrics

    if not self.session_id:
        raise ValueError("session_id is required to fetch metrics.")

    headers = {
        "x-bb-api-key": self.browserbase_api_key,
        "x-bb-project-id": self.browserbase_project_id,
        "Content-Type": "application/json",
    }

    # Methods that have their own counters on StagehandMetrics; anything else
    # only contributes to the totals.
    tracked_methods = ("act", "extract", "observe", "agent")

    try:
        response = await self._client.get(
            f"{self.api_url}/sessions/{self.session_id}/replay",
            headers=headers,
        )

        if response.status_code != 200:
            error_text = (
                await response.aread() if hasattr(response, "aread") else response.text
            )
            # aread() yields raw bytes; decode so the message does not render
            # as b'...' in the log and in the raised error.
            if isinstance(error_text, (bytes, bytearray)):
                error_text = bytes(error_text).decode("utf-8", errors="replace")
            self.logger.error(
                f"[HTTP ERROR] Failed to fetch metrics. Status {response.status_code}: {error_text}"
            )
            raise RuntimeError(
                f"Failed to fetch metrics with status {response.status_code}: {error_text}"
            )

        data = response.json()

        if not data.get("success"):
            raise RuntimeError(
                f"Failed to fetch metrics: {data.get('error', 'Unknown error')}"
            )

        # Parse the API data into StagehandMetrics format.
        # `or` (rather than a .get default) also covers keys that are present
        # but explicitly null in the JSON payload.
        api_data = data.get("data") or {}
        metrics = StagehandMetrics()

        for page in api_data.get("pages") or []:
            for action in page.get("actions") or []:
                token_usage = action.get("tokenUsage")
                if not token_usage:
                    continue

                method = (action.get("method") or "").lower()
                amounts = (
                    ("prompt_tokens", token_usage.get("inputTokens") or 0),
                    ("completion_tokens", token_usage.get("outputTokens") or 0),
                    ("inference_time_ms", token_usage.get("timeMs") or 0),
                )

                # Totals are always updated; the per-method counters only for
                # the methods StagehandMetrics tracks.
                prefixes = ["total"]
                if method in tracked_methods:
                    prefixes.insert(0, method)

                for prefix in prefixes:
                    for suffix, amount in amounts:
                        field = f"{prefix}_{suffix}"
                        setattr(metrics, field, getattr(metrics, field) + amount)

        return metrics

    except Exception as e:
        self.logger.error(f"[EXCEPTION] Error fetching replay metrics: {str(e)}")
        raise

stagehand/main.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Optional
88

99
import httpx
10+
import nest_asyncio
1011
from dotenv import load_dotenv
1112
from playwright.async_api import (
1213
BrowserContext,
@@ -16,7 +17,7 @@
1617
from playwright.async_api import Page as PlaywrightPage
1718

1819
from .agent import Agent
19-
from .api import _create_session, _execute
20+
from .api import _create_session, _execute, _get_replay_metrics
2021
from .browser import (
2122
cleanup_browser_resources,
2223
connect_browserbase_browser,
@@ -206,7 +207,7 @@ def __init__(
206207
)
207208

208209
# Initialize metrics tracking
209-
self.metrics = StagehandMetrics()
210+
self._local_metrics = StagehandMetrics() # Internal storage for local metrics
210211
self._inference_start_time = 0 # To track inference time
211212

212213
# Validate env
@@ -372,26 +373,26 @@ def update_metrics(
372373
inference_time_ms: Time taken for inference in milliseconds
373374
"""
374375
if function_name == StagehandFunctionName.ACT:
375-
self.metrics.act_prompt_tokens += prompt_tokens
376-
self.metrics.act_completion_tokens += completion_tokens
377-
self.metrics.act_inference_time_ms += inference_time_ms
376+
self._local_metrics.act_prompt_tokens += prompt_tokens
377+
self._local_metrics.act_completion_tokens += completion_tokens
378+
self._local_metrics.act_inference_time_ms += inference_time_ms
378379
elif function_name == StagehandFunctionName.EXTRACT:
379-
self.metrics.extract_prompt_tokens += prompt_tokens
380-
self.metrics.extract_completion_tokens += completion_tokens
381-
self.metrics.extract_inference_time_ms += inference_time_ms
380+
self._local_metrics.extract_prompt_tokens += prompt_tokens
381+
self._local_metrics.extract_completion_tokens += completion_tokens
382+
self._local_metrics.extract_inference_time_ms += inference_time_ms
382383
elif function_name == StagehandFunctionName.OBSERVE:
383-
self.metrics.observe_prompt_tokens += prompt_tokens
384-
self.metrics.observe_completion_tokens += completion_tokens
385-
self.metrics.observe_inference_time_ms += inference_time_ms
384+
self._local_metrics.observe_prompt_tokens += prompt_tokens
385+
self._local_metrics.observe_completion_tokens += completion_tokens
386+
self._local_metrics.observe_inference_time_ms += inference_time_ms
386387
elif function_name == StagehandFunctionName.AGENT:
387-
self.metrics.agent_prompt_tokens += prompt_tokens
388-
self.metrics.agent_completion_tokens += completion_tokens
389-
self.metrics.agent_inference_time_ms += inference_time_ms
388+
self._local_metrics.agent_prompt_tokens += prompt_tokens
389+
self._local_metrics.agent_completion_tokens += completion_tokens
390+
self._local_metrics.agent_inference_time_ms += inference_time_ms
390391

391392
# Always update totals
392-
self.metrics.total_prompt_tokens += prompt_tokens
393-
self.metrics.total_completion_tokens += completion_tokens
394-
self.metrics.total_inference_time_ms += inference_time_ms
393+
self._local_metrics.total_prompt_tokens += prompt_tokens
394+
self._local_metrics.total_completion_tokens += completion_tokens
395+
self._local_metrics.total_inference_time_ms += inference_time_ms
395396

396397
def update_metrics_from_response(
397398
self,
@@ -426,9 +427,9 @@ def update_metrics_from_response(
426427
f"{completion_tokens} completion tokens, {time_ms}ms"
427428
)
428429
self.logger.debug(
429-
f"Total metrics: {self.metrics.total_prompt_tokens} prompt tokens, "
430-
f"{self.metrics.total_completion_tokens} completion tokens, "
431-
f"{self.metrics.total_inference_time_ms}ms"
430+
f"Total metrics: {self._local_metrics.total_prompt_tokens} prompt tokens, "
431+
f"{self._local_metrics.total_completion_tokens} completion tokens, "
432+
f"{self._local_metrics.total_inference_time_ms}ms"
432433
)
433434
else:
434435
# Try to extract from _hidden_params or other locations
@@ -736,7 +737,50 @@ def page(self) -> Optional[StagehandPage]:
736737

737738
return self._live_page_proxy
738739

740+
def __getattribute__(self, name):
    """
    Intercept access to 'metrics' to fetch from API when use_api=True.

    Any attribute other than ``metrics`` is resolved with the normal lookup.
    For ``metrics``:

    * when ``use_api`` is falsy (or not set on the instance yet) the locally
      tracked ``_local_metrics`` object is returned;
    * when ``use_api`` is truthy the replay metrics are fetched synchronously
      through ``_get_replay_metrics``. If that fails for any reason the error
      is logged and an empty ``StagehandMetrics`` is returned instead of
      raising, so reading ``metrics`` stays best-effort.
    """
    # For all other attributes, use normal behavior
    if name != "metrics":
        return object.__getattribute__(self, name)

    # Direct lookup instead of hasattr(): hasattr would re-enter this hook.
    try:
        use_api = object.__getattribute__(self, "use_api")
    except AttributeError:
        use_api = False

    if not use_api:
        # Return local metrics
        return object.__getattribute__(self, "_local_metrics")

    try:
        get_replay_metrics = object.__getattribute__(self, "_get_replay_metrics")

        # Only the loop probe is guarded by RuntimeError. The fetch itself can
        # raise RuntimeError too (HTTP/API failure) and must not be mistaken
        # for "no running loop", which would send the request a second time.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            loop_is_running = False
        else:
            loop_is_running = True

        if loop_is_running:
            # Already inside an event loop: patch asyncio so that
            # asyncio.run() may be nested in the running loop.
            nest_asyncio.apply()

        return asyncio.run(get_replay_metrics())
    except Exception as e:
        # Log error and return empty metrics
        logger = object.__getattribute__(self, "logger")
        if logger:
            logger.error(f"Failed to fetch metrics from API: {str(e)}")
        return StagehandMetrics()
781+
739782

740783
# Bind the imported API methods to the Stagehand class
741784
Stagehand._create_session = _create_session
742785
Stagehand._execute = _execute
786+
Stagehand._get_replay_metrics = _get_replay_metrics

0 commit comments

Comments
 (0)