From adc89e2e5df264ef7a7e8b17f4971a1086212099 Mon Sep 17 00:00:00 2001
From: Derek Meegan
Date: Thu, 18 Sep 2025 14:13:52 -0700
Subject: [PATCH 1/4] make litellm async

---
 stagehand/handlers/extract_handler.py |   2 +-
 stagehand/handlers/observe_handler.py |   2 +-
 stagehand/llm/client.py               |  14 +--
 stagehand/llm/inference.py            |  10 +-
 test_async_performance.py             | 138 ++++++++++++++++++++++++++
 tests/mocks/mock_llm.py               |  18 ++--
 6 files changed, 161 insertions(+), 23 deletions(-)
 create mode 100644 test_async_performance.py

diff --git a/stagehand/handlers/extract_handler.py b/stagehand/handlers/extract_handler.py
index ec9b9fa5..8af621c9 100644
--- a/stagehand/handlers/extract_handler.py
+++ b/stagehand/handlers/extract_handler.py
@@ -105,7 +105,7 @@ async def extract(
             schema = transformed_schema = DefaultExtractSchema
 
         # Use inference to call the LLM
-        extraction_response = extract_inference(
+        extraction_response = await extract_inference(
             instruction=instruction,
             tree_elements=output_string,
             schema=transformed_schema,
diff --git a/stagehand/handlers/observe_handler.py b/stagehand/handlers/observe_handler.py
index f0f29181..5acce6d5 100644
--- a/stagehand/handlers/observe_handler.py
+++ b/stagehand/handlers/observe_handler.py
@@ -74,7 +74,7 @@ async def observe(
         iframes = tree.get("iframes", [])
 
         # use inference to call the llm
-        observation_response = observe_inference(
+        observation_response = await observe_inference(
            instruction=instruction,
            tree_elements=output_string,
            llm_client=self.stagehand.llm,
diff --git a/stagehand/llm/client.py b/stagehand/llm/client.py
index e9fbefe5..06dc9594 100644
--- a/stagehand/llm/client.py
+++ b/stagehand/llm/client.py
@@ -60,7 +60,7 @@ def __init__(
                 f"Set global litellm.api_base to {value}", category="llm"
             )
 
-    def create_response(
+    async def create_response(
         self,
         *,
         messages: list[dict[str, str]],
@@ -77,7 +77,7 @@ def create_response(
                 Overrides the default_model if provided.
             function_name: The name of the Stagehand function calling this method (ACT, OBSERVE, etc.)
                 Used for metrics tracking.
-            **kwargs: Additional parameters to pass directly to litellm.completion
+            **kwargs: Additional parameters to pass directly to litellm.acompletion
                 (e.g., temperature, max_tokens, stream=True, specific provider arguments).
 
         Returns:
@@ -87,7 +87,7 @@ def create_response(
 
         Raises:
             ValueError: If no model is specified (neither default nor in the call).
-            Exception: Propagates exceptions from litellm.completion.
+            Exception: Propagates exceptions from litellm.acompletion.
""" completion_model = model or self.default_model if not completion_model: @@ -115,7 +115,7 @@ def create_response( filtered_params["temperature"] = 1 self.logger.debug( - f"Calling litellm.completion with model={completion_model} and params: {filtered_params}", + f"Calling litellm.acompletion with model={completion_model} and params: {filtered_params}", category="llm", ) @@ -123,8 +123,8 @@ def create_response( # Start tracking inference time start_time = start_inference_timer() - # Use litellm's completion function - response = litellm.completion(**filtered_params) + # Use litellm's async completion function + response = await litellm.acompletion(**filtered_params) # Calculate inference time inference_time_ms = get_inference_time_ms(start_time) @@ -136,6 +136,6 @@ def create_response( return response except Exception as e: - self.logger.error(f"Error calling litellm.completion: {e}", category="llm") + self.logger.error(f"Error calling litellm.acompletion: {e}", category="llm") # Consider more specific exception handling based on litellm errors raise diff --git a/stagehand/llm/inference.py b/stagehand/llm/inference.py index 24f0de91..b438883b 100644 --- a/stagehand/llm/inference.py +++ b/stagehand/llm/inference.py @@ -21,7 +21,7 @@ # TODO: kwargs -def observe( +async def observe( instruction: str, tree_elements: str, llm_client: Any, @@ -66,7 +66,7 @@ def observe( try: # Call the LLM logger.info("Calling LLM") - response = llm_client.create_response( + response = await llm_client.create_response( model=llm_client.default_model, messages=messages, response_format=ObserveInferenceSchema, @@ -123,7 +123,7 @@ def observe( } -def extract( +async def extract( instruction: str, tree_elements: str, schema: Optional[Union[type[BaseModel], dict]] = None, @@ -177,7 +177,7 @@ def extract( # Call the LLM with appropriate parameters try: - extract_response = llm_client.create_response( + extract_response = await llm_client.create_response( model=llm_client.default_model, messages=extract_messages, response_format=response_format, @@ -227,7 +227,7 @@ def extract( # Call LLM for metadata try: metadata_start_time = time.time() - metadata_response = llm_client.create_response( + metadata_response = await llm_client.create_response( model=llm_client.default_model, messages=metadata_messages, response_format=metadata_schema, diff --git a/test_async_performance.py b/test_async_performance.py new file mode 100644 index 00000000..550549ec --- /dev/null +++ b/test_async_performance.py @@ -0,0 +1,138 @@ +"""Test script to verify async LLM calls are non-blocking""" + +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock +from stagehand.llm.client import LLMClient +from stagehand.llm.inference import observe, extract + + +async def simulate_slow_llm_response(delay=1.0): + """Simulate a slow LLM API response""" + await asyncio.sleep(delay) + return MagicMock( + usage=MagicMock(prompt_tokens=100, completion_tokens=50), + choices=[MagicMock(message=MagicMock(content='{"elements": []}'))] + ) + + +async def test_parallel_execution(): + """Test that multiple LLM calls can run in parallel""" + print("\n🧪 Testing parallel async execution...") + + # Create mock LLM client + mock_logger = MagicMock() + mock_logger.info = MagicMock() + mock_logger.debug = MagicMock() + mock_logger.error = MagicMock() + + llm_client = LLMClient( + stagehand_logger=mock_logger, + default_model="gpt-4o" + ) + + # Mock the async create_response to simulate delay + async def mock_create_response(**kwargs): + return await 
simulate_slow_llm_response(1.0) + + llm_client.create_response = mock_create_response + + # Measure time for parallel execution + start_time = time.time() + + # Run 3 observe calls in parallel + tasks = [ + observe("Find button 1", "DOM content 1", llm_client, logger=mock_logger), + observe("Find button 2", "DOM content 2", llm_client, logger=mock_logger), + observe("Find button 3", "DOM content 3", llm_client, logger=mock_logger), + ] + + results = await asyncio.gather(*tasks) + parallel_time = time.time() - start_time + + print(f"✅ Parallel execution of 3 calls took: {parallel_time:.2f}s") + print(f" Expected ~1s (running in parallel), not 3s (sequential)") + + # Verify results + assert len(results) == 3 + for i, result in enumerate(results, 1): + assert "elements" in result + print(f" Result {i}: {result}") + + # Test sequential execution for comparison + print("\n🧪 Testing sequential execution for comparison...") + start_time = time.time() + + result1 = await observe("Find button 1", "DOM content 1", llm_client, logger=mock_logger) + result2 = await observe("Find button 2", "DOM content 2", llm_client, logger=mock_logger) + result3 = await observe("Find button 3", "DOM content 3", llm_client, logger=mock_logger) + + sequential_time = time.time() - start_time + print(f"✅ Sequential execution of 3 calls took: {sequential_time:.2f}s") + print(f" Expected ~3s (running sequentially)") + + # Parallel should be significantly faster + assert parallel_time < sequential_time * 0.5, "Parallel execution should be much faster than sequential" + + print(f"\n🎉 Async implementation is working correctly!") + print(f" Parallel speedup: {sequential_time/parallel_time:.2f}x faster") + + +async def test_real_llm_async(): + """Test with real LiteLLM to ensure the async implementation works""" + print("\n🧪 Testing with real LiteLLM (using mock responses)...") + + import litellm + from unittest.mock import patch + + # Mock litellm.acompletion to return test data + async def mock_acompletion(**kwargs): + await asyncio.sleep(0.1) # Small delay to simulate API call + return MagicMock( + usage=MagicMock(prompt_tokens=100, completion_tokens=50), + choices=[MagicMock(message=MagicMock(content='{"elements": [{"selector": "#test"}]}'))] + ) + + with patch('litellm.acompletion', new=mock_acompletion): + mock_logger = MagicMock() + mock_logger.info = MagicMock() + mock_logger.debug = MagicMock() + mock_logger.error = MagicMock() + + llm_client = LLMClient( + stagehand_logger=mock_logger, + default_model="gpt-4o" + ) + + # Test that the actual async call works + response = await llm_client.create_response( + messages=[{"role": "user", "content": "test"}], + model="gpt-4o" + ) + + assert response is not None + print(f"✅ Real LiteLLM async call successful") + print(f" Response: {response.choices[0].message.content}") + + +async def main(): + """Run all tests""" + print("=" * 50) + print("ASYNC IMPLEMENTATION VERIFICATION") + print("=" * 50) + + try: + await test_parallel_execution() + await test_real_llm_async() + + print("\n" + "=" * 50) + print("✅ ALL TESTS PASSED - ASYNC IS WORKING!") + print("=" * 50) + + except Exception as e: + print(f"\n❌ Test failed: {e}") + raise + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/tests/mocks/mock_llm.py b/tests/mocks/mock_llm.py index 7c53275e..e1cd744f 100644 --- a/tests/mocks/mock_llm.py +++ b/tests/mocks/mock_llm.py @@ -258,7 +258,7 @@ def get_usage_stats(self) -> Dict[str, int]: "total_tokens": total_prompt_tokens + 
total_completion_tokens } - def create_response( + async def create_response( self, *, messages: list[dict[str, str]], @@ -274,13 +274,13 @@ def create_response( # Fall back to content-based detection content = str(messages).lower() response_type = self._determine_response_type(content) - + # Track the call self.call_count += 1 self.last_messages = messages self.last_model = model or self.default_model self.last_kwargs = kwargs - + # Store call in history call_info = { "messages": messages, @@ -290,26 +290,26 @@ def create_response( "timestamp": asyncio.get_event_loop().time() } self.call_history.append(call_info) - + # Simulate failure if configured if self.should_fail: raise Exception(self.failure_message) - + # Check for custom responses first if response_type in self.custom_responses: response_data = self.custom_responses[response_type] if callable(response_data): response_data = response_data(messages, **kwargs) return self._create_response(response_data, model=self.last_model) - + # Use default response mapping response_generator = self.response_mapping.get(response_type, self._default_response) response_data = response_generator(messages, **kwargs) - + response = self._create_response(response_data, model=self.last_model) - + # Call metrics callback if set if self.metrics_callback: self.metrics_callback(response, 100, response_type) # 100ms mock inference time - + return response \ No newline at end of file From a787fc69255563c501c94de554fe196d52f5bdd7 Mon Sep 17 00:00:00 2001 From: Derek Meegan Date: Thu, 18 Sep 2025 17:10:29 -0700 Subject: [PATCH 2/4] add changeset --- .changeset/whispering-brave-potoo.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/whispering-brave-potoo.md diff --git a/.changeset/whispering-brave-potoo.md b/.changeset/whispering-brave-potoo.md new file mode 100644 index 00000000..e7b37afd --- /dev/null +++ b/.changeset/whispering-brave-potoo.md @@ -0,0 +1,5 @@ +--- +"stagehand": patch +--- + +Make litellm client async From 754a8de2668c6e89dacef50b7cde291b9b63c98e Mon Sep 17 00:00:00 2001 From: Derek Meegan Date: Thu, 25 Sep 2025 14:12:08 -0700 Subject: [PATCH 3/4] remove extra file --- test_async_performance.py | 138 -------------------------------------- 1 file changed, 138 deletions(-) delete mode 100644 test_async_performance.py diff --git a/test_async_performance.py b/test_async_performance.py deleted file mode 100644 index 550549ec..00000000 --- a/test_async_performance.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Test script to verify async LLM calls are non-blocking""" - -import asyncio -import time -from unittest.mock import AsyncMock, MagicMock -from stagehand.llm.client import LLMClient -from stagehand.llm.inference import observe, extract - - -async def simulate_slow_llm_response(delay=1.0): - """Simulate a slow LLM API response""" - await asyncio.sleep(delay) - return MagicMock( - usage=MagicMock(prompt_tokens=100, completion_tokens=50), - choices=[MagicMock(message=MagicMock(content='{"elements": []}'))] - ) - - -async def test_parallel_execution(): - """Test that multiple LLM calls can run in parallel""" - print("\n🧪 Testing parallel async execution...") - - # Create mock LLM client - mock_logger = MagicMock() - mock_logger.info = MagicMock() - mock_logger.debug = MagicMock() - mock_logger.error = MagicMock() - - llm_client = LLMClient( - stagehand_logger=mock_logger, - default_model="gpt-4o" - ) - - # Mock the async create_response to simulate delay - async def mock_create_response(**kwargs): - return await 
simulate_slow_llm_response(1.0) - - llm_client.create_response = mock_create_response - - # Measure time for parallel execution - start_time = time.time() - - # Run 3 observe calls in parallel - tasks = [ - observe("Find button 1", "DOM content 1", llm_client, logger=mock_logger), - observe("Find button 2", "DOM content 2", llm_client, logger=mock_logger), - observe("Find button 3", "DOM content 3", llm_client, logger=mock_logger), - ] - - results = await asyncio.gather(*tasks) - parallel_time = time.time() - start_time - - print(f"✅ Parallel execution of 3 calls took: {parallel_time:.2f}s") - print(f" Expected ~1s (running in parallel), not 3s (sequential)") - - # Verify results - assert len(results) == 3 - for i, result in enumerate(results, 1): - assert "elements" in result - print(f" Result {i}: {result}") - - # Test sequential execution for comparison - print("\n🧪 Testing sequential execution for comparison...") - start_time = time.time() - - result1 = await observe("Find button 1", "DOM content 1", llm_client, logger=mock_logger) - result2 = await observe("Find button 2", "DOM content 2", llm_client, logger=mock_logger) - result3 = await observe("Find button 3", "DOM content 3", llm_client, logger=mock_logger) - - sequential_time = time.time() - start_time - print(f"✅ Sequential execution of 3 calls took: {sequential_time:.2f}s") - print(f" Expected ~3s (running sequentially)") - - # Parallel should be significantly faster - assert parallel_time < sequential_time * 0.5, "Parallel execution should be much faster than sequential" - - print(f"\n🎉 Async implementation is working correctly!") - print(f" Parallel speedup: {sequential_time/parallel_time:.2f}x faster") - - -async def test_real_llm_async(): - """Test with real LiteLLM to ensure the async implementation works""" - print("\n🧪 Testing with real LiteLLM (using mock responses)...") - - import litellm - from unittest.mock import patch - - # Mock litellm.acompletion to return test data - async def mock_acompletion(**kwargs): - await asyncio.sleep(0.1) # Small delay to simulate API call - return MagicMock( - usage=MagicMock(prompt_tokens=100, completion_tokens=50), - choices=[MagicMock(message=MagicMock(content='{"elements": [{"selector": "#test"}]}'))] - ) - - with patch('litellm.acompletion', new=mock_acompletion): - mock_logger = MagicMock() - mock_logger.info = MagicMock() - mock_logger.debug = MagicMock() - mock_logger.error = MagicMock() - - llm_client = LLMClient( - stagehand_logger=mock_logger, - default_model="gpt-4o" - ) - - # Test that the actual async call works - response = await llm_client.create_response( - messages=[{"role": "user", "content": "test"}], - model="gpt-4o" - ) - - assert response is not None - print(f"✅ Real LiteLLM async call successful") - print(f" Response: {response.choices[0].message.content}") - - -async def main(): - """Run all tests""" - print("=" * 50) - print("ASYNC IMPLEMENTATION VERIFICATION") - print("=" * 50) - - try: - await test_parallel_execution() - await test_real_llm_async() - - print("\n" + "=" * 50) - print("✅ ALL TESTS PASSED - ASYNC IS WORKING!") - print("=" * 50) - - except Exception as e: - print(f"\n❌ Test failed: {e}") - raise - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file From f29484da036777f518f692b3af1a8fea081916b3 Mon Sep 17 00:00:00 2001 From: Derek Meegan Date: Thu, 25 Sep 2025 16:19:33 -0700 Subject: [PATCH 4/4] format --- stagehand/handlers/act_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/stagehand/handlers/act_handler.py b/stagehand/handlers/act_handler.py index 27ae6bd2..d6a7ccac 100644 --- a/stagehand/handlers/act_handler.py +++ b/stagehand/handlers/act_handler.py @@ -97,7 +97,7 @@ async def act(self, options: Union[ActOptions, ObserveResult]) -> ActResult: variables = options.get("variables", {}) element_to_act_on.arguments = [ str(arg).replace(f"%{key}%", str(value)) - for arg in (element_to_act_on.arguments or []) + for arg in element_to_act_on.arguments or [] for key, value in variables.items() ]
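
With create_response, observe, and extract now defined as coroutines that await litellm.acompletion, callers on a single event loop can overlap several inference calls instead of blocking on each request in turn. A minimal usage sketch under that assumption (the instructions, tree string, and the observe_many helper below are illustrative placeholders, not part of the patch):

    import asyncio

    from stagehand.llm.inference import observe

    async def observe_many(llm_client, logger, tree_elements):
        # Each observe() call awaits llm_client.create_response, which now wraps
        # litellm.acompletion, so gather() runs the three requests concurrently.
        tasks = [
            observe("Find the search box", tree_elements, llm_client, logger=logger),
            observe("Find the submit button", tree_elements, llm_client, logger=logger),
            observe("Find the results list", tree_elements, llm_client, logger=logger),
        ]
        return await asyncio.gather(*tasks)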