diff --git a/.changeset/dog-of-interesting-energy.md b/.changeset/dog-of-interesting-energy.md
new file mode 100644
index 0000000..002c28b
--- /dev/null
+++ b/.changeset/dog-of-interesting-energy.md
@@ -0,0 +1,5 @@
+---
+"stagehand": patch
+---
+
+Fix stagehand.metrics on env:BROWSERBASE
diff --git a/pyproject.toml b/pyproject.toml
index f894f9f..2844514 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,19 @@ description = "Python SDK for Stagehand"
 readme = "README.md"
 classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent",]
 requires-python = ">=3.9"
-dependencies = [ "httpx>=0.24.0", "python-dotenv>=1.0.0", "pydantic>=1.10.0", "playwright>=1.42.1", "requests>=2.31.0", "browserbase>=1.4.0", "rich>=13.7.0", "openai>=1.83.0", "anthropic>=0.51.0", "litellm>=1.72.0",]
+dependencies = [
+    "httpx>=0.24.0",
+    "python-dotenv>=1.0.0",
+    "pydantic>=1.10.0",
+    "playwright>=1.42.1",
+    "requests>=2.31.0",
+    "browserbase>=1.4.0",
+    "rich>=13.7.0",
+    "openai>=1.83.0",
+    "anthropic>=0.51.0",
+    "litellm>=1.72.0",
+    "nest-asyncio>=1.6.0",
+]
 [[project.authors]]
 name = "Browserbase, Inc."
 email = "support@browserbase.com"
diff --git a/stagehand/api.py b/stagehand/api.py
index 1773de5..1df2620 100644
--- a/stagehand/api.py
+++ b/stagehand/api.py
@@ -3,9 +3,10 @@ from importlib.metadata import PackageNotFoundError, version
 from typing import Any
 
+from .metrics import StagehandMetrics
 from .utils import convert_dict_keys_to_camel_case
 
-__all__ = ["_create_session", "_execute"]
+__all__ = ["_create_session", "_execute", "_get_replay_metrics"]
 
 
 async def _create_session(self):
@@ -177,3 +178,89 @@ async def _execute(self, method: str, payload: dict[str, Any]) -> Any:
     except Exception as e:
         self.logger.error(f"[EXCEPTION] {str(e)}")
         raise
+
+
+async def _get_replay_metrics(self):
+    """
+    Fetch replay metrics from the API and parse them into StagehandMetrics.
+    """
+
+    if not self.session_id:
+        raise ValueError("session_id is required to fetch metrics.")
+
+    headers = {
+        "x-bb-api-key": self.browserbase_api_key,
+        "x-bb-project-id": self.browserbase_project_id,
+        "Content-Type": "application/json",
+    }
+
+    try:
+        response = await self._client.get(
+            f"{self.api_url}/sessions/{self.session_id}/replay",
+            headers=headers,
+        )
+
+        if response.status_code != 200:
+            error_text = (
+                await response.aread() if hasattr(response, "aread") else response.text
+            )
+            self.logger.error(
+                f"[HTTP ERROR] Failed to fetch metrics. Status {response.status_code}: {error_text}"
+            )
+            raise RuntimeError(
+                f"Failed to fetch metrics with status {response.status_code}: {error_text}"
+            )
+
+        data = response.json()
+
+        if not data.get("success"):
+            raise RuntimeError(
+                f"Failed to fetch metrics: {data.get('error', 'Unknown error')}"
+            )
+
+        # Parse the API data into StagehandMetrics format
+        api_data = data.get("data", {})
+        metrics = StagehandMetrics()
+
+        # Parse pages and their actions
+        pages = api_data.get("pages", [])
+        for page in pages:
+            actions = page.get("actions", [])
+            for action in actions:
+                # Get method name and token usage
+                method = action.get("method", "").lower()
+                token_usage = action.get("tokenUsage", {})
+
+                if token_usage:
+                    input_tokens = token_usage.get("inputTokens", 0)
+                    output_tokens = token_usage.get("outputTokens", 0)
+                    time_ms = token_usage.get("timeMs", 0)
+
+                    # Map method to metrics fields
+                    if method == "act":
+                        metrics.act_prompt_tokens += input_tokens
+                        metrics.act_completion_tokens += output_tokens
+                        metrics.act_inference_time_ms += time_ms
+                    elif method == "extract":
+                        metrics.extract_prompt_tokens += input_tokens
+                        metrics.extract_completion_tokens += output_tokens
+                        metrics.extract_inference_time_ms += time_ms
+                    elif method == "observe":
+                        metrics.observe_prompt_tokens += input_tokens
+                        metrics.observe_completion_tokens += output_tokens
+                        metrics.observe_inference_time_ms += time_ms
+                    elif method == "agent":
+                        metrics.agent_prompt_tokens += input_tokens
+                        metrics.agent_completion_tokens += output_tokens
+                        metrics.agent_inference_time_ms += time_ms
+
+                    # Always update totals for any method with token usage
+                    metrics.total_prompt_tokens += input_tokens
+                    metrics.total_completion_tokens += output_tokens
+                    metrics.total_inference_time_ms += time_ms
+
+        return metrics
+
+    except Exception as e:
+        self.logger.error(f"[EXCEPTION] Error fetching replay metrics: {str(e)}")
+        raise
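Note on the payload: _get_replay_metrics assumes the replay endpoint wraps its result in a success/data envelope, with per-page actions that each carry a method name and a tokenUsage object. The key names below (success, data, pages, actions, method, tokenUsage, inputTokens, outputTokens, timeMs) are exactly the keys the parser reads; the concrete values are made up for illustration.

    # Illustrative replay response shape, as consumed by _get_replay_metrics.
    # Values are hypothetical; only the key names are taken from the parser above.
    example_replay_response = {
        "success": True,
        "data": {
            "pages": [
                {
                    "actions": [
                        {
                            "method": "ACT",  # compared case-insensitively via .lower()
                            "tokenUsage": {
                                "inputTokens": 1200,
                                "outputTokens": 85,
                                "timeMs": 2300,
                            },
                        }
                    ]
                }
            ]
        },
    }

Actions whose method is not one of act/extract/observe/agent still contribute to the total_* fields, since the totals are updated for any action that reports token usage.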
diff --git a/stagehand/main.py b/stagehand/main.py
index 45ee8c2..4a201ad 100644
--- a/stagehand/main.py
+++ b/stagehand/main.py
@@ -7,6 +7,7 @@ from typing import Any, Optional
 
 import httpx
+import nest_asyncio
 from dotenv import load_dotenv
 from playwright.async_api import (
     BrowserContext,
@@ -16,7 +17,7 @@ from playwright.async_api import Page as PlaywrightPage
 
 from .agent import Agent
-from .api import _create_session, _execute
+from .api import _create_session, _execute, _get_replay_metrics
 from .browser import (
     cleanup_browser_resources,
     connect_browserbase_browser,
@@ -206,7 +207,7 @@ def __init__(
         )
 
         # Initialize metrics tracking
-        self.metrics = StagehandMetrics()
+        self._local_metrics = StagehandMetrics()  # Internal storage for local metrics
         self._inference_start_time = 0  # To track inference time
 
         # Validate env
@@ -372,26 +373,26 @@ def update_metrics(
             inference_time_ms: Time taken for inference in milliseconds
         """
         if function_name == StagehandFunctionName.ACT:
-            self.metrics.act_prompt_tokens += prompt_tokens
-            self.metrics.act_completion_tokens += completion_tokens
-            self.metrics.act_inference_time_ms += inference_time_ms
+            self._local_metrics.act_prompt_tokens += prompt_tokens
+            self._local_metrics.act_completion_tokens += completion_tokens
+            self._local_metrics.act_inference_time_ms += inference_time_ms
         elif function_name == StagehandFunctionName.EXTRACT:
-            self.metrics.extract_prompt_tokens += prompt_tokens
-            self.metrics.extract_completion_tokens += completion_tokens
-            self.metrics.extract_inference_time_ms += inference_time_ms
+            self._local_metrics.extract_prompt_tokens += prompt_tokens
+            self._local_metrics.extract_completion_tokens += completion_tokens
+            self._local_metrics.extract_inference_time_ms += inference_time_ms
         elif function_name == StagehandFunctionName.OBSERVE:
-            self.metrics.observe_prompt_tokens += prompt_tokens
-            self.metrics.observe_completion_tokens += completion_tokens
-            self.metrics.observe_inference_time_ms += inference_time_ms
+            self._local_metrics.observe_prompt_tokens += prompt_tokens
+            self._local_metrics.observe_completion_tokens += completion_tokens
+            self._local_metrics.observe_inference_time_ms += inference_time_ms
         elif function_name == StagehandFunctionName.AGENT:
-            self.metrics.agent_prompt_tokens += prompt_tokens
-            self.metrics.agent_completion_tokens += completion_tokens
-            self.metrics.agent_inference_time_ms += inference_time_ms
+            self._local_metrics.agent_prompt_tokens += prompt_tokens
+            self._local_metrics.agent_completion_tokens += completion_tokens
+            self._local_metrics.agent_inference_time_ms += inference_time_ms
 
         # Always update totals
-        self.metrics.total_prompt_tokens += prompt_tokens
-        self.metrics.total_completion_tokens += completion_tokens
-        self.metrics.total_inference_time_ms += inference_time_ms
+        self._local_metrics.total_prompt_tokens += prompt_tokens
+        self._local_metrics.total_completion_tokens += completion_tokens
+        self._local_metrics.total_inference_time_ms += inference_time_ms
 
     def update_metrics_from_response(
         self,
@@ -426,9 +427,9 @@ def update_metrics_from_response(
                 f"{completion_tokens} completion tokens, {time_ms}ms"
             )
             self.logger.debug(
-                f"Total metrics: {self.metrics.total_prompt_tokens} prompt tokens, "
-                f"{self.metrics.total_completion_tokens} completion tokens, "
-                f"{self.metrics.total_inference_time_ms}ms"
+                f"Total metrics: {self._local_metrics.total_prompt_tokens} prompt tokens, "
+                f"{self._local_metrics.total_completion_tokens} completion tokens, "
+                f"{self._local_metrics.total_inference_time_ms}ms"
             )
         else:
             # Try to extract from _hidden_params or other locations
@@ -736,7 +737,50 @@ def page(self) -> Optional[StagehandPage]:
 
         return self._live_page_proxy
 
+    def __getattribute__(self, name):
+        """
+        Intercept access to 'metrics' to fetch from API when use_api=True.
+        """
+        if name == "metrics":
+            use_api = (
+                object.__getattribute__(self, "use_api")
+                if hasattr(self, "use_api")
+                else False
+            )
+
+            if use_api:
+                # Need to fetch from API
+                try:
+                    # Get the _get_replay_metrics method
+                    get_replay_metrics = object.__getattribute__(
+                        self, "_get_replay_metrics"
+                    )
+
+                    # Try to get current event loop
+                    try:
+                        asyncio.get_running_loop()
+                        # We're inside a running loop, where a plain asyncio.run()
+                        # would raise; patch the loop with nest_asyncio so the
+                        # nested run is allowed
+                        nest_asyncio.apply()
+                        return asyncio.run(get_replay_metrics())
+                    except RuntimeError:
+                        # No event loop running, we can use asyncio.run directly
+                        return asyncio.run(get_replay_metrics())
+                except Exception as e:
+                    # Log error and return empty metrics
+                    logger = object.__getattribute__(self, "logger")
+                    if logger:
+                        logger.error(f"Failed to fetch metrics from API: {str(e)}")
+                    return StagehandMetrics()
+            else:
+                # Return local metrics
+                return object.__getattribute__(self, "_local_metrics")
+
+        # For all other attributes, use normal behavior
+        return object.__getattribute__(self, name)
+
 
 # Bind the imported API methods to the Stagehand class
 Stagehand._create_session = _create_session
 Stagehand._execute = _execute
+Stagehand._get_replay_metrics = _get_replay_metrics
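Net effect: stagehand.metrics stays a plain attribute read at call sites. In local mode it returns the in-process _local_metrics; when use_api is true it is resolved through __getattribute__, which bridges into the async _get_replay_metrics via nest_asyncio (the reason for the new dependency, since asyncio.run() cannot normally be called from inside a running event loop). A minimal usage sketch, assuming the usual StagehandConfig/init() entry points, Browserbase credentials in the environment, and illustrative page actions:

    import asyncio

    from stagehand import Stagehand, StagehandConfig

    async def main():
        # Assumption: env="BROWSERBASE" puts the SDK in API mode (use_api=True),
        # with credentials read from BROWSERBASE_API_KEY / BROWSERBASE_PROJECT_ID.
        stagehand = Stagehand(StagehandConfig(env="BROWSERBASE"))
        await stagehand.init()
        try:
            await stagehand.page.goto("https://example.com")
            await stagehand.page.act("click the first link")

            # Plain attribute access; with use_api=True, __getattribute__ fetches
            # and aggregates the session's replay metrics. The nested asyncio.run()
            # works here because nest_asyncio has patched the running loop.
            metrics = stagehand.metrics
            print(metrics.total_prompt_tokens, metrics.total_inference_time_ms)
        finally:
            await stagehand.close()

    asyncio.run(main())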