@@ -202,6 +202,246 @@ def test_scheduler_startup_mode_thread(self):
202202 # Stop the scheduler
203203 self .scheduler .stop ()
204204
def test_robustness(self):
    """Test dispatcher robustness when thread pool is overwhelmed with tasks.

    Phases:
      1. Dispatch three times as many slow tasks as the pool has workers.
      2. Dispatch fast tasks while the slow ones are still in flight.
      3. Print the dispatcher's running-task bookkeeping.
      4. Wait (bounded) for the handlers to report and check throughput.
      5. Dispatch a few more tasks to confirm the pool still accepts work.
      6. After stopping the scheduler, push the dispatcher monitor to the
         brink of its failure limit and run one health check.
    """
    import threading
    import time

    # Recreate dispatcher with smaller thread pool
    from memos.context.context import ContextThreadPoolExecutor

    # Create a scheduler with a small thread pool for testing
    small_max_workers = 3
    self.scheduler.dispatcher.max_workers = small_max_workers

    if self.scheduler.dispatcher.dispatcher_executor:
        self.scheduler.dispatcher.dispatcher_executor.shutdown(wait=True)

    self.scheduler.dispatcher.dispatcher_executor = ContextThreadPoolExecutor(
        max_workers=small_max_workers, thread_name_prefix="test_dispatcher"
    )

    # Track task completion; handlers run on pool threads, so every access
    # to these lists goes through task_lock.
    completed_tasks = []
    failed_tasks = []
    task_lock = threading.Lock()

    def slow_handler(messages: list[ScheduleMessageItem]) -> None:
        """Handler that simulates slow processing to overwhelm thread pool."""
        try:
            task_id = messages[0].content if messages else "unknown"
            time.sleep(0.02)  # Simulate slow processing (20ms)
            with task_lock:
                completed_tasks.append(task_id)
        except Exception as e:
            with task_lock:
                failed_tasks.append(str(e))

    def fast_handler(messages: list[ScheduleMessageItem]) -> None:
        """Handler for quick tasks to test mixed workload."""
        try:
            task_id = messages[0].content if messages else "unknown"
            time.sleep(0.001)  # Quick processing (1ms)
            with task_lock:
                completed_tasks.append(f"fast_{task_id}")
        except Exception as e:
            with task_lock:
                failed_tasks.append(str(e))

    def build_messages(label, task_prefix, owner_prefix, count):
        """Build `count` messages for `label` with index-suffixed identifiers."""
        return [
            ScheduleMessageItem(
                label=label,
                content=f"{task_prefix}_task_{i}",
                user_id=f"{owner_prefix}_user_{i}",
                mem_cube_id=f"{owner_prefix}_mem_cube_{i}",
                mem_cube=f"{owner_prefix}_mem_cube_obj",
                timestamp=datetime.now(),
            )
            for i in range(count)
        ]

    def wait_for_results(expected, timeout, poll_interval=0.01):
        """Poll until `expected` handler outcomes are recorded or `timeout` elapses."""
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            with task_lock:
                if len(completed_tasks) + len(failed_tasks) >= expected:
                    return
            time.sleep(poll_interval)

    # Register handlers
    slow_label = "slow_task"
    fast_label = "fast_task"
    self.scheduler.register_handlers({slow_label: slow_handler, fast_label: fast_handler})

    max_wait_time = 0.5  # Upper bound for each polling wait below

    # Start the scheduler
    self.scheduler.start()
    try:
        # Test 1: Overwhelm thread pool with slow tasks
        print("Test 1: Overwhelming thread pool with slow tasks...")
        num_slow_tasks = small_max_workers * 3  # 9 tasks for 3 workers
        slow_messages = build_messages(slow_label, "slow", "test", num_slow_tasks)

        # Submit all slow tasks at once - directly dispatch instead of using submit_messages
        start_time = time.monotonic()
        try:
            # Directly dispatch messages to bypass queue and immediately start processing
            self.scheduler.dispatcher.dispatch(slow_messages)
        except Exception as e:
            # Deliberately best-effort: the assertions below decide pass/fail.
            print(f"Exception during task dispatch: {e}")

        # Test 2: Add fast tasks while slow tasks are running
        print("Test 2: Adding fast tasks while thread pool is busy...")
        time.sleep(0.005)  # Let slow tasks start

        num_fast_tasks = 5
        fast_messages = build_messages(fast_label, "fast", "fast", num_fast_tasks)

        try:
            # Directly dispatch fast messages
            self.scheduler.dispatcher.dispatch(fast_messages)
        except Exception as e:
            print(f"Exception during fast task dispatch: {e}")

        # Test 3: Check thread pool status during overload
        print("Test 3: Monitoring thread pool status...")
        running_tasks = self.scheduler.dispatcher.get_running_tasks()
        running_count = self.scheduler.dispatcher.get_running_task_count()
        print(f"Running tasks count: {running_count}")
        print(f"Running tasks: {list(running_tasks.keys())}")

        # Test 4: Wait for some tasks to complete and verify recovery
        print("Test 4: Waiting for task completion and recovery...")
        expected_total = num_slow_tasks + num_fast_tasks
        wait_for_results(expected_total, max_wait_time)

        # Final verification: snapshot under the lock so late-finishing
        # handlers cannot mutate the lists while they are read.
        execution_time = time.monotonic() - start_time
        with task_lock:
            completed_snapshot = list(completed_tasks)
            failed_snapshot = list(failed_tasks)
        final_completed = len(completed_snapshot)
        final_failed = len(failed_snapshot)

        print(f"Execution completed in {execution_time:.2f} seconds")
        print(f"Completed tasks: {final_completed}")
        print(f"Failed tasks: {final_failed}")
        print(f"Completed task IDs: {completed_snapshot}")
        if failed_snapshot:
            print(f"Failed task errors: {failed_snapshot}")

        # Assertions for robustness test
        # At least some tasks should complete successfully
        self.assertGreater(final_completed, 0, "No tasks completed successfully")

        # Total processed should be reasonable (allowing for some failures under stress)
        total_processed = final_completed + final_failed
        self.assertGreaterEqual(
            total_processed,
            expected_total * 0.7,  # Allow 30% failure rate under extreme stress
            f"Too few tasks processed: {total_processed}/{expected_total}",
        )

        # At least one fast task must have made it through the saturated pool
        fast_completed = [task for task in completed_snapshot if task.startswith("fast_")]
        self.assertGreater(len(fast_completed), 0, "No fast tasks completed")

        # Test 5: Verify thread pool recovery after stress
        print("Test 5: Testing thread pool recovery...")
        recovery_messages = build_messages(fast_label, "recovery", "recovery", 3)

        # Clear previous results
        with task_lock:
            completed_tasks.clear()
            failed_tasks.clear()

        # Submit recovery tasks - directly dispatch
        try:
            self.scheduler.dispatcher.dispatch(recovery_messages)
        except Exception as e:
            print(f"Exception during recovery task dispatch: {e}")

        # Poll instead of a fixed short sleep so a busy machine does not
        # turn a slow-but-healthy pool into a spurious failure.
        wait_for_results(len(recovery_messages), max_wait_time)

        with task_lock:
            recovery_completed = len(completed_tasks)
            recovery_failed = len(failed_tasks)

        print(f"Recovery test - Completed: {recovery_completed}, Failed: {recovery_failed}")

        # Recovery tasks should complete successfully
        self.assertGreaterEqual(
            recovery_completed,
            len(recovery_messages) * 0.8,  # Allow some margin
            "Thread pool did not recover properly after stress test",
        )
    finally:
        # Stop the scheduler even when an assertion above fails, so worker
        # threads do not outlive this test.
        self.scheduler.stop()

    # Test 6: Simulate dispatcher monitor restart functionality
    print("Test 6: Testing dispatcher monitor restart functionality...")

    # Force a failure condition by setting failure count high
    monitor = self.scheduler.dispatcher_monitor
    if monitor and hasattr(monitor, "_pools"):
        # NOTE(review): the health check below runs while _pool_lock is held;
        # this presumably relies on the lock being reentrant or on
        # _check_pools_health not taking it - confirm against the monitor.
        with monitor._pool_lock:
            pool_name = monitor.dispatcher_pool_name
            if pool_name in monitor._pools:
                # Simulate multiple failures to trigger restart
                monitor._pools[pool_name]["failure_count"] = monitor.max_failures - 1
                monitor._pools[pool_name]["healthy"] = False
                print(f"Set failure count to {monitor._pools[pool_name]['failure_count']}")

                # Trigger one more failure to cause restart
                monitor._check_pools_health()

                # Wait a bit for restart to complete
                time.sleep(0.02)

                # Check if pool was restarted (failure count should be reset)
                if pool_name in monitor._pools:
                    final_failure_count = monitor._pools[pool_name]["failure_count"]
                    is_healthy = monitor._pools[pool_name]["healthy"]
                    print(
                        f"After restart - Failure count: {final_failure_count}, Healthy: {is_healthy}"
                    )

                    # Verify restart worked (unittest assertion: survives
                    # `python -O`, unlike a bare assert)
                    self.assertLess(
                        final_failure_count,
                        monitor.max_failures,
                        f"Expected failure count to be reset, got {final_failure_count}",
                    )
                    print("Dispatcher monitor restart functionality verified!")
                else:
                    print("Pool not found after restart attempt")
            else:
                print(f"Pool {pool_name} not found in monitor registry")
    else:
        print("Dispatcher monitor not available or pools not accessible")

    print("Robustness test completed successfully!")

    # Verify cleanup
    self.assertFalse(self.scheduler._running)
207447
0 commit comments