
Commit adc89e2

make litellm async
1 parent ad95605 commit adc89e2

File tree

6 files changed: +161 −23 lines changed

stagehand/handlers/extract_handler.py
stagehand/handlers/observe_handler.py
stagehand/llm/client.py
stagehand/llm/inference.py
test_async_performance.py
tests/mocks/mock_llm.py

stagehand/handlers/extract_handler.py

Lines changed: 1 addition & 1 deletion

@@ -105,7 +105,7 @@ async def extract(
         schema = transformed_schema = DefaultExtractSchema
 
         # Use inference to call the LLM
-        extraction_response = extract_inference(
+        extraction_response = await extract_inference(
             instruction=instruction,
             tree_elements=output_string,
             schema=transformed_schema,

stagehand/handlers/observe_handler.py

Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ async def observe(
         iframes = tree.get("iframes", [])
 
         # use inference to call the llm
-        observation_response = observe_inference(
+        observation_response = await observe_inference(
             instruction=instruction,
             tree_elements=output_string,
             llm_client=self.stagehand.llm,
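
Both handler edits above are the same one-token migration: the handlers' extract and observe methods are already async def, so once the inference helpers become coroutines (see stagehand/llm/inference.py below) their results must be awaited rather than used directly. A minimal, self-contained sketch of the failure mode this avoids; the observe_inference stand-in here is hypothetical, not the library's code:

import asyncio

async def observe_inference(instruction: str) -> dict:
    # Hypothetical stand-in for the now-async inference helper.
    await asyncio.sleep(0.01)
    return {"elements": []}

async def handler_missing_await(instruction: str):
    # Without `await`, this returns a coroutine object, not the result dict.
    return observe_inference(instruction)

async def handler_with_await(instruction: str):
    return await observe_inference(instruction)

async def main():
    good = await handler_with_await("find the login button")
    bad = await handler_missing_await("find the login button")
    print(good)       # {'elements': []}
    print(type(bad))  # <class 'coroutine'>, plus a "never awaited" RuntimeWarning

asyncio.run(main())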

stagehand/llm/client.py

Lines changed: 7 additions & 7 deletions

@@ -60,7 +60,7 @@ def __init__(
                 f"Set global litellm.api_base to {value}", category="llm"
             )
 
-    def create_response(
+    async def create_response(
         self,
         *,
         messages: list[dict[str, str]],
@@ -77,7 +77,7 @@ def create_response(
                 Overrides the default_model if provided.
             function_name: The name of the Stagehand function calling this method (ACT, OBSERVE, etc.)
                 Used for metrics tracking.
-            **kwargs: Additional parameters to pass directly to litellm.completion
+            **kwargs: Additional parameters to pass directly to litellm.acompletion
                 (e.g., temperature, max_tokens, stream=True, specific provider arguments).
 
         Returns:
@@ -87,7 +87,7 @@ def create_response(
 
         Raises:
             ValueError: If no model is specified (neither default nor in the call).
-            Exception: Propagates exceptions from litellm.completion.
+            Exception: Propagates exceptions from litellm.acompletion.
         """
         completion_model = model or self.default_model
         if not completion_model:
@@ -115,16 +115,16 @@ def create_response(
             filtered_params["temperature"] = 1
 
         self.logger.debug(
-            f"Calling litellm.completion with model={completion_model} and params: {filtered_params}",
+            f"Calling litellm.acompletion with model={completion_model} and params: {filtered_params}",
            category="llm",
        )
 
        try:
            # Start tracking inference time
            start_time = start_inference_timer()
 
-            # Use litellm's completion function
-            response = litellm.completion(**filtered_params)
+            # Use litellm's async completion function
+            response = await litellm.acompletion(**filtered_params)
 
            # Calculate inference time
            inference_time_ms = get_inference_time_ms(start_time)
@@ -136,6 +136,6 @@ def create_response(
            return response
 
        except Exception as e:
-            self.logger.error(f"Error calling litellm.completion: {e}", category="llm")
+            self.logger.error(f"Error calling litellm.acompletion: {e}", category="llm")
            # Consider more specific exception handling based on litellm errors
            raise
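
Because create_response now awaits litellm.acompletion, the event loop stays free while a completion is in flight, so independent calls can be overlapped. A minimal sketch under assumptions: the constructor arguments mirror the test script added in this commit, a MagicMock stands in for the Stagehand logger, and real provider credentials would be needed for the calls to actually succeed:

import asyncio
from unittest.mock import MagicMock

from stagehand.llm.client import LLMClient

async def main():
    client = LLMClient(stagehand_logger=MagicMock(), default_model="gpt-4o")

    # Two independent completions now overlap instead of running back to back.
    first, second = await asyncio.gather(
        client.create_response(messages=[{"role": "user", "content": "Say hello."}]),
        client.create_response(messages=[{"role": "user", "content": "Say goodbye."}]),
    )
    print(first.choices[0].message.content)
    print(second.choices[0].message.content)

asyncio.run(main())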

stagehand/llm/inference.py

Lines changed: 5 additions & 5 deletions

@@ -21,7 +21,7 @@
 
 
 # TODO: kwargs
-def observe(
+async def observe(
     instruction: str,
     tree_elements: str,
     llm_client: Any,
@@ -66,7 +66,7 @@ def observe(
     try:
         # Call the LLM
         logger.info("Calling LLM")
-        response = llm_client.create_response(
+        response = await llm_client.create_response(
             model=llm_client.default_model,
             messages=messages,
             response_format=ObserveInferenceSchema,
@@ -123,7 +123,7 @@ def observe(
     }
 
 
-def extract(
+async def extract(
     instruction: str,
     tree_elements: str,
     schema: Optional[Union[type[BaseModel], dict]] = None,
@@ -177,7 +177,7 @@ def extract(
 
     # Call the LLM with appropriate parameters
     try:
-        extract_response = llm_client.create_response(
+        extract_response = await llm_client.create_response(
             model=llm_client.default_model,
             messages=extract_messages,
             response_format=response_format,
@@ -227,7 +227,7 @@ def extract(
     # Call LLM for metadata
     try:
         metadata_start_time = time.time()
-        metadata_response = llm_client.create_response(
+        metadata_response = await llm_client.create_response(
             model=llm_client.default_model,
             messages=metadata_messages,
             response_format=metadata_schema,
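
With observe and extract now coroutines, any caller that is not itself async has to drive an event loop. A minimal sketch of a synchronous bridge, assuming extract also accepts llm_client as a keyword argument (as observe does in the test script below); this wrapper is illustrative, not part of the library:

import asyncio
from stagehand.llm.inference import extract

def extract_sync(instruction: str, tree_elements: str, llm_client, **kwargs):
    # Bridge for synchronous callers: run the now-async helper to completion.
    # Only valid when no event loop is already running in this thread.
    return asyncio.run(
        extract(
            instruction=instruction,
            tree_elements=tree_elements,
            llm_client=llm_client,
            **kwargs,
        )
    )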

test_async_performance.py

Lines changed: 138 additions & 0 deletions (new file)

"""Test script to verify async LLM calls are non-blocking"""

import asyncio
import time
from unittest.mock import AsyncMock, MagicMock
from stagehand.llm.client import LLMClient
from stagehand.llm.inference import observe, extract


async def simulate_slow_llm_response(delay=1.0):
    """Simulate a slow LLM API response"""
    await asyncio.sleep(delay)
    return MagicMock(
        usage=MagicMock(prompt_tokens=100, completion_tokens=50),
        choices=[MagicMock(message=MagicMock(content='{"elements": []}'))]
    )


async def test_parallel_execution():
    """Test that multiple LLM calls can run in parallel"""
    print("\n🧪 Testing parallel async execution...")

    # Create mock LLM client
    mock_logger = MagicMock()
    mock_logger.info = MagicMock()
    mock_logger.debug = MagicMock()
    mock_logger.error = MagicMock()

    llm_client = LLMClient(
        stagehand_logger=mock_logger,
        default_model="gpt-4o"
    )

    # Mock the async create_response to simulate delay
    async def mock_create_response(**kwargs):
        return await simulate_slow_llm_response(1.0)

    llm_client.create_response = mock_create_response

    # Measure time for parallel execution
    start_time = time.time()

    # Run 3 observe calls in parallel
    tasks = [
        observe("Find button 1", "DOM content 1", llm_client, logger=mock_logger),
        observe("Find button 2", "DOM content 2", llm_client, logger=mock_logger),
        observe("Find button 3", "DOM content 3", llm_client, logger=mock_logger),
    ]

    results = await asyncio.gather(*tasks)
    parallel_time = time.time() - start_time

    print(f"✅ Parallel execution of 3 calls took: {parallel_time:.2f}s")
    print("   Expected ~1s (running in parallel), not 3s (sequential)")

    # Verify results
    assert len(results) == 3
    for i, result in enumerate(results, 1):
        assert "elements" in result
        print(f"   Result {i}: {result}")

    # Test sequential execution for comparison
    print("\n🧪 Testing sequential execution for comparison...")
    start_time = time.time()

    result1 = await observe("Find button 1", "DOM content 1", llm_client, logger=mock_logger)
    result2 = await observe("Find button 2", "DOM content 2", llm_client, logger=mock_logger)
    result3 = await observe("Find button 3", "DOM content 3", llm_client, logger=mock_logger)

    sequential_time = time.time() - start_time
    print(f"✅ Sequential execution of 3 calls took: {sequential_time:.2f}s")
    print("   Expected ~3s (running sequentially)")

    # Parallel should be significantly faster
    assert parallel_time < sequential_time * 0.5, "Parallel execution should be much faster than sequential"

    print("\n🎉 Async implementation is working correctly!")
    print(f"   Parallel speedup: {sequential_time/parallel_time:.2f}x faster")


async def test_real_llm_async():
    """Test with real LiteLLM to ensure the async implementation works"""
    print("\n🧪 Testing with real LiteLLM (using mock responses)...")

    import litellm
    from unittest.mock import patch

    # Mock litellm.acompletion to return test data
    async def mock_acompletion(**kwargs):
        await asyncio.sleep(0.1)  # Small delay to simulate API call
        return MagicMock(
            usage=MagicMock(prompt_tokens=100, completion_tokens=50),
            choices=[MagicMock(message=MagicMock(content='{"elements": [{"selector": "#test"}]}'))]
        )

    with patch('litellm.acompletion', new=mock_acompletion):
        mock_logger = MagicMock()
        mock_logger.info = MagicMock()
        mock_logger.debug = MagicMock()
        mock_logger.error = MagicMock()

        llm_client = LLMClient(
            stagehand_logger=mock_logger,
            default_model="gpt-4o"
        )

        # Test that the actual async call works
        response = await llm_client.create_response(
            messages=[{"role": "user", "content": "test"}],
            model="gpt-4o"
        )

        assert response is not None
        print("✅ Real LiteLLM async call successful")
        print(f"   Response: {response.choices[0].message.content}")


async def main():
    """Run all tests"""
    print("=" * 50)
    print("ASYNC IMPLEMENTATION VERIFICATION")
    print("=" * 50)

    try:
        await test_parallel_execution()
        await test_real_llm_async()

        print("\n" + "=" * 50)
        print("✅ ALL TESTS PASSED - ASYNC IS WORKING!")
        print("=" * 50)

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        raise


if __name__ == "__main__":
    asyncio.run(main())
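
Assuming the package is importable from the repository root, the script can presumably be run directly with python test_async_performance.py. Both tests replace the network layer (one by swapping out create_response, one by patching litellm.acompletion), so no real provider key appears to be required.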

tests/mocks/mock_llm.py

Lines changed: 9 additions & 9 deletions

@@ -258,7 +258,7 @@ def get_usage_stats(self) -> Dict[str, int]:
             "total_tokens": total_prompt_tokens + total_completion_tokens
         }
 
-    def create_response(
+    async def create_response(
         self,
         *,
         messages: list[dict[str, str]],
@@ -274,13 +274,13 @@ def create_response(
         # Fall back to content-based detection
         content = str(messages).lower()
         response_type = self._determine_response_type(content)
-
+
         # Track the call
         self.call_count += 1
         self.last_messages = messages
         self.last_model = model or self.default_model
         self.last_kwargs = kwargs
-
+
         # Store call in history
         call_info = {
             "messages": messages,
@@ -290,26 +290,26 @@ def create_response(
             "timestamp": asyncio.get_event_loop().time()
         }
         self.call_history.append(call_info)
-
+
         # Simulate failure if configured
         if self.should_fail:
             raise Exception(self.failure_message)
-
+
         # Check for custom responses first
         if response_type in self.custom_responses:
             response_data = self.custom_responses[response_type]
             if callable(response_data):
                 response_data = response_data(messages, **kwargs)
             return self._create_response(response_data, model=self.last_model)
-
+
         # Use default response mapping
         response_generator = self.response_mapping.get(response_type, self._default_response)
         response_data = response_generator(messages, **kwargs)
-
+
         response = self._create_response(response_data, model=self.last_model)
-
+
         # Call metrics callback if set
        if self.metrics_callback:
            self.metrics_callback(response, 100, response_type)  # 100ms mock inference time
-
+
         return response
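
Since the mock client's create_response is now a coroutine, it keeps the same awaitable contract as the real LLMClient. Ad-hoc stubs in other tests should follow suit; a minimal sketch using the standard library's AsyncMock (not the repository's mock class):

import asyncio
from unittest.mock import AsyncMock

async def demo():
    # AsyncMock makes the stubbed method awaitable, matching the new interface.
    fake_client = AsyncMock()
    fake_client.create_response.return_value = {"choices": []}

    response = await fake_client.create_response(
        messages=[{"role": "user", "content": "click the login button"}]
    )
    print(response)  # {'choices': []}

asyncio.run(demo())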
