Skip to content

Commit bb17aef

Browse files
committed
feat: add a test_robustness execution to test thread pool execution
1 parent 2e70c15 commit bb17aef

File tree

1 file changed

+240
-0
lines changed

1 file changed

+240
-0
lines changed

tests/mem_scheduler/test_scheduler.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,246 @@ def test_scheduler_startup_mode_thread(self):
202202
# Stop the scheduler
203203
self.scheduler.stop()
204204

def test_robustness(self):
    """Stress-test dispatcher robustness when the thread pool is overwhelmed.

    Exercises six phases:
      1. Flood a deliberately small (3-worker) pool with slow tasks.
      2. Inject fast tasks while the pool is saturated (mixed workload).
      3. Inspect running-task bookkeeping during overload.
      4. Wait for completion and assert a reasonable processing rate.
      5. Verify the pool recovers and processes fresh tasks after stress.
      6. Force the dispatcher monitor into a failure state and verify its
         restart logic resets the failure count.

    NOTE(review): this test is timing-based (sleeps were shrunk from seconds
    to milliseconds) and may be sensitive to very slow CI machines.
    """
    import threading
    import time

    # Shrink the dispatcher's pool so it is easy to saturate.
    small_max_workers = 3
    self.scheduler.dispatcher.max_workers = small_max_workers

    # Recreate the executor with the smaller worker count; shut down the
    # old one first so its threads do not interfere with the test.
    from memos.context.context import ContextThreadPoolExecutor

    if self.scheduler.dispatcher.dispatcher_executor:
        self.scheduler.dispatcher.dispatcher_executor.shutdown(wait=True)

    self.scheduler.dispatcher.dispatcher_executor = ContextThreadPoolExecutor(
        max_workers=small_max_workers, thread_name_prefix="test_dispatcher"
    )

    # Shared result buffers, guarded by a lock because handlers run on
    # multiple pool threads concurrently.
    completed_tasks = []
    failed_tasks = []
    task_lock = threading.Lock()

    def slow_handler(messages: list[ScheduleMessageItem]) -> None:
        """Handler that simulates slow processing to overwhelm thread pool."""
        try:
            task_id = messages[0].content if messages else "unknown"
            # Simulate slow processing (reduced from 2.0s to 20ms)
            time.sleep(0.02)
            with task_lock:
                completed_tasks.append(task_id)
        except Exception as e:
            with task_lock:
                failed_tasks.append(str(e))

    def fast_handler(messages: list[ScheduleMessageItem]) -> None:
        """Handler for quick tasks to test mixed workload."""
        try:
            task_id = messages[0].content if messages else "unknown"
            time.sleep(0.001)  # Quick processing (reduced from 0.1s to 1ms)
            with task_lock:
                completed_tasks.append(f"fast_{task_id}")
        except Exception as e:
            with task_lock:
                failed_tasks.append(str(e))

    # Register both handlers under distinct labels.
    slow_label = "slow_task"
    fast_label = "fast_task"
    self.scheduler.register_handlers({slow_label: slow_handler, fast_label: fast_handler})

    # Start the scheduler
    self.scheduler.start()

    # Test 1: Overwhelm thread pool with slow tasks (3x more tasks than workers).
    print("Test 1: Overwhelming thread pool with slow tasks...")
    num_slow_tasks = small_max_workers * 3  # 9 tasks for 3 workers

    slow_messages = []
    for i in range(num_slow_tasks):
        message = ScheduleMessageItem(
            label=slow_label,
            content=f"slow_task_{i}",
            user_id=f"test_user_{i}",
            mem_cube_id=f"test_mem_cube_{i}",
            mem_cube="test_mem_cube_obj",
            timestamp=datetime.now(),
        )
        slow_messages.append(message)

    # Submit all slow tasks at once - directly dispatch instead of using submit_messages
    start_time = time.time()
    try:
        # Directly dispatch messages to bypass queue and immediately start processing
        self.scheduler.dispatcher.dispatch(slow_messages)
    except Exception as e:
        print(f"Exception during task dispatch: {e}")

    # Test 2: Add fast tasks while slow tasks are running
    print("Test 2: Adding fast tasks while thread pool is busy...")
    time.sleep(0.005)  # Let slow tasks start (reduced from 0.5s to 5ms)

    num_fast_tasks = 5
    fast_messages = []
    for i in range(num_fast_tasks):
        message = ScheduleMessageItem(
            label=fast_label,
            content=f"fast_task_{i}",
            user_id=f"fast_user_{i}",
            mem_cube_id=f"fast_mem_cube_{i}",
            mem_cube="fast_mem_cube_obj",
            timestamp=datetime.now(),
        )
        fast_messages.append(message)

    try:
        # Directly dispatch fast messages
        self.scheduler.dispatcher.dispatch(fast_messages)
    except Exception as e:
        print(f"Exception during fast task dispatch: {e}")

    # Test 3: Check thread pool status during overload
    print("Test 3: Monitoring thread pool status...")
    running_tasks = self.scheduler.dispatcher.get_running_tasks()
    running_count = self.scheduler.dispatcher.get_running_task_count()
    print(f"Running tasks count: {running_count}")
    print(f"Running tasks: {list(running_tasks.keys())}")

    # Test 4: Poll until all tasks are accounted for or the deadline passes.
    print("Test 4: Waiting for task completion and recovery...")
    max_wait_time = 0.5  # Maximum wait time (reduced from 15.0s to 0.5s)
    wait_start = time.time()

    while time.time() - wait_start < max_wait_time:
        with task_lock:
            total_completed = len(completed_tasks)
            total_failed = len(failed_tasks)

        if total_completed + total_failed >= num_slow_tasks + num_fast_tasks:
            break

        time.sleep(0.01)  # Check every 10ms (reduced from 1.0s)

    # Final verification
    execution_time = time.time() - start_time
    with task_lock:
        final_completed = len(completed_tasks)
        final_failed = len(failed_tasks)

    print(f"Execution completed in {execution_time:.2f} seconds")
    print(f"Completed tasks: {final_completed}")
    print(f"Failed tasks: {final_failed}")
    print(f"Completed task IDs: {completed_tasks}")
    if failed_tasks:
        print(f"Failed task errors: {failed_tasks}")

    # Assertions for robustness test
    # At least some tasks should complete successfully
    self.assertGreater(final_completed, 0, "No tasks completed successfully")

    # Total processed should be reasonable (allowing for some failures under stress)
    total_processed = final_completed + final_failed
    expected_total = num_slow_tasks + num_fast_tasks
    self.assertGreaterEqual(
        total_processed,
        expected_total * 0.7,  # Allow 30% failure rate under extreme stress
        f"Too few tasks processed: {total_processed}/{expected_total}",
    )

    # Fast tasks should generally complete faster than slow tasks
    fast_completed = [task for task in completed_tasks if task.startswith("fast_")]
    self.assertGreater(len(fast_completed), 0, "No fast tasks completed")

    # Test 5: Verify thread pool recovery after stress
    print("Test 5: Testing thread pool recovery...")
    recovery_messages = []
    for i in range(3):  # Small number of recovery tasks
        message = ScheduleMessageItem(
            label=fast_label,
            content=f"recovery_task_{i}",
            user_id=f"recovery_user_{i}",
            mem_cube_id=f"recovery_mem_cube_{i}",
            mem_cube="recovery_mem_cube_obj",
            timestamp=datetime.now(),
        )
        recovery_messages.append(message)

    # Clear previous results
    with task_lock:
        completed_tasks.clear()
        failed_tasks.clear()

    # Submit recovery tasks - directly dispatch
    try:
        self.scheduler.dispatcher.dispatch(recovery_messages)
    except Exception as e:
        print(f"Exception during recovery task dispatch: {e}")

    # Wait for recovery tasks to be processed
    time.sleep(0.05)  # Give time for recovery tasks to complete (reduced from 3.0s to 50ms)

    with task_lock:
        recovery_completed = len(completed_tasks)
        recovery_failed = len(failed_tasks)

    print(f"Recovery test - Completed: {recovery_completed}, Failed: {recovery_failed}")

    # Recovery tasks should complete successfully
    self.assertGreaterEqual(
        recovery_completed,
        len(recovery_messages) * 0.8,  # Allow some margin
        "Thread pool did not recover properly after stress test",
    )

    # Stop the scheduler
    self.scheduler.stop()

    # Test 6: Simulate dispatcher monitor restart functionality
    print("Test 6: Testing dispatcher monitor restart functionality...")

    # Force a failure condition by setting failure count high
    monitor = self.scheduler.dispatcher_monitor
    if monitor and hasattr(monitor, "_pools"):
        with monitor._pool_lock:
            pool_name = monitor.dispatcher_pool_name
            if pool_name in monitor._pools:
                # Simulate multiple failures to trigger restart
                monitor._pools[pool_name]["failure_count"] = monitor.max_failures - 1
                monitor._pools[pool_name]["healthy"] = False
                print(f"Set failure count to {monitor._pools[pool_name]['failure_count']}")

                # Trigger one more failure to cause restart
                monitor._check_pools_health()

                # Wait a bit for restart to complete
                time.sleep(0.02)  # Reduced from 2s to 20ms

                # Check if pool was restarted (failure count should be reset)
                if pool_name in monitor._pools:
                    final_failure_count = monitor._pools[pool_name]["failure_count"]
                    is_healthy = monitor._pools[pool_name]["healthy"]
                    print(
                        f"After restart - Failure count: {final_failure_count}, Healthy: {is_healthy}"
                    )

                    # Verify restart worked. Use a unittest assertion (not a
                    # bare `assert`) so the check survives `python -O` and is
                    # consistent with the other assertions in this test.
                    self.assertLess(
                        final_failure_count,
                        monitor.max_failures,
                        f"Expected failure count to be reset, got {final_failure_count}",
                    )
                    print("Dispatcher monitor restart functionality verified!")
                else:
                    print("Pool not found after restart attempt")
            else:
                print(f"Pool {pool_name} not found in monitor registry")
    else:
        print("Dispatcher monitor not available or pools not accessible")

    print("Robustness test completed successfully!")
205445
# Verify cleanup
206446
self.assertFalse(self.scheduler._running)
207447

0 commit comments

Comments
 (0)