"""Test script to verify async LLM calls are non-blocking"""

import asyncio
import time
from unittest.mock import AsyncMock, MagicMock
from stagehand.llm.client import LLMClient
from stagehand.llm.inference import observe, extract


async def simulate_slow_llm_response(delay=1.0):
    """Simulate a slow LLM API response"""
    await asyncio.sleep(delay)
    return MagicMock(
        usage=MagicMock(prompt_tokens=100, completion_tokens=50),
        choices=[MagicMock(message=MagicMock(content='{"elements": []}'))]
    )


async def test_parallel_execution():
20
+ """Test that multiple LLM calls can run in parallel"""
21
+ print ("\n 🧪 Testing parallel async execution..." )
22
+
23
+ # Create mock LLM client
24
+ mock_logger = MagicMock ()
25
+ mock_logger .info = MagicMock ()
26
+ mock_logger .debug = MagicMock ()
27
+ mock_logger .error = MagicMock ()
28
+
29
+ llm_client = LLMClient (
30
+ stagehand_logger = mock_logger ,
31
+ default_model = "gpt-4o"
32
+ )
33
+
34
+ # Mock the async create_response to simulate delay
35
+ async def mock_create_response (** kwargs ):
36
+ return await simulate_slow_llm_response (1.0 )
37
+
38
+ llm_client .create_response = mock_create_response
39
+
40
+ # Measure time for parallel execution
41
+ start_time = time .time ()
42
+
43
+ # Run 3 observe calls in parallel
44
+ tasks = [
45
+ observe ("Find button 1" , "DOM content 1" , llm_client , logger = mock_logger ),
46
+ observe ("Find button 2" , "DOM content 2" , llm_client , logger = mock_logger ),
47
+ observe ("Find button 3" , "DOM content 3" , llm_client , logger = mock_logger ),
48
+ ]
49
+
50
+ results = await asyncio .gather (* tasks )
51
+ parallel_time = time .time () - start_time
52
+
53
+ print (f"✅ Parallel execution of 3 calls took: { parallel_time :.2f} s" )
54
+ print (f" Expected ~1s (running in parallel), not 3s (sequential)" )
55
+
56
+ # Verify results
57
+ assert len (results ) == 3
58
+ for i , result in enumerate (results , 1 ):
59
+ assert "elements" in result
60
+ print (f" Result { i } : { result } " )
61
+
62
+ # Test sequential execution for comparison
63
+ print ("\n 🧪 Testing sequential execution for comparison..." )
64
+ start_time = time .time ()
65
+
66
+ result1 = await observe ("Find button 1" , "DOM content 1" , llm_client , logger = mock_logger )
67
+ result2 = await observe ("Find button 2" , "DOM content 2" , llm_client , logger = mock_logger )
68
+ result3 = await observe ("Find button 3" , "DOM content 3" , llm_client , logger = mock_logger )
69
+
70
+ sequential_time = time .time () - start_time
71
+ print (f"✅ Sequential execution of 3 calls took: { sequential_time :.2f} s" )
72
+ print (f" Expected ~3s (running sequentially)" )
73
+
74
+ # Parallel should be significantly faster
75
+ assert parallel_time < sequential_time * 0.5 , "Parallel execution should be much faster than sequential"
76
+
77
+ print (f"\n 🎉 Async implementation is working correctly!" )
78
+ print (f" Parallel speedup: { sequential_time / parallel_time :.2f} x faster" )
79
+
80
+
81
+ async def test_real_llm_async ():
    """Test with real LiteLLM to ensure the async implementation works"""
    print("\n🧪 Testing with real LiteLLM (using mock responses)...")

    import litellm
    from unittest.mock import patch

    # Mock litellm.acompletion to return test data
    async def mock_acompletion(**kwargs):
        await asyncio.sleep(0.1)  # Small delay to simulate API call
        return MagicMock(
            usage=MagicMock(prompt_tokens=100, completion_tokens=50),
            choices=[MagicMock(message=MagicMock(content='{"elements": [{"selector": "#test"}]}'))]
        )

    with patch('litellm.acompletion', new=mock_acompletion):
        mock_logger = MagicMock()
        mock_logger.info = MagicMock()
        mock_logger.debug = MagicMock()
        mock_logger.error = MagicMock()

        llm_client = LLMClient(
            stagehand_logger=mock_logger,
            default_model="gpt-4o"
        )

        # Test that the actual async call works
        response = await llm_client.create_response(
            messages=[{"role": "user", "content": "test"}],
            model="gpt-4o"
        )

        assert response is not None
        print("✅ Real LiteLLM async call successful")
        print(f"   Response: {response.choices[0].message.content}")


async def main():
    """Run all tests"""
    print("=" * 50)
    print("ASYNC IMPLEMENTATION VERIFICATION")
    print("=" * 50)

    try:
        await test_parallel_execution()
        await test_real_llm_async()

        print("\n" + "=" * 50)
        print("✅ ALL TESTS PASSED - ASYNC IS WORKING!")
        print("=" * 50)

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        raise


if __name__ == "__main__":
    asyncio.run(main())