@@ -53,15 +53,13 @@ def __init__(self, config: BaseSchedulerConfig):
5353
5454 # hyper-parameters
5555 self .top_k = self .config .get ("top_k" , 5 )
56- self .act_mem_update_interval = self .config .get ("act_mem_update_interval" , 300 )
5756 self .context_window_size = self .config .get ("context_window_size" , 5 )
5857 self .activation_mem_size = self .config .get (
5958 "activation_mem_size" , DEFAULT_ACTIVATION_MEM_SIZE
6059 )
6160 self .enable_act_memory_update = self .config .get ("enable_act_memory_update" , False )
6261 self .act_mem_dump_path = self .config .get ("act_mem_dump_path" , DEFAULT_ACT_MEM_DUMP_PATH )
6362 self .search_method = TreeTextMemory_SEARCH_METHOD
64- self ._last_activation_mem_update_time = 0.0
6563
6664 self .enable_parallel_dispatch = self .config .get ("enable_parallel_dispatch" , False )
6765 self .max_workers = self .config .get (
@@ -143,6 +141,22 @@ def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None:
143141 self ._current_mem_cube_id = msg .mem_cube_id
144142 self ._current_mem_cube = msg .mem_cube
145143
144+ def _validate_messages (self , messages : list [ScheduleMessageItem ], label : str ):
145+ """Validate if all messages match the expected label.
146+
147+ Args:
148+ messages: List of message items to validate.
149+ label: Expected message label (e.g., QUERY_LABEL/ANSWER_LABEL).
150+
151+ Returns:
152+ bool: True if all messages passed validation, False if any failed.
153+ """
154+ for message in messages :
155+ if not self ._validate_message (message , label ):
156+ return False
157+ logger .error ("Message batch contains invalid labels, aborting processing" )
158+ return True
159+
146160 def _validate_message (self , message : ScheduleMessageItem , label : str ):
147161 """Validate if the message matches the expected label.
148162
@@ -270,62 +284,6 @@ def log_working_memory_replacement(
270284 logger .info (f"{ len (added_memories )} { LONG_TERM_MEMORY_TYPE } memorie(s) "
271285 f"transformed to { WORKING_MEMORY_TYPE } memories." )
272286
273- def update_activation_memory (
274- self ,
275- new_memories : list [str | TextualMemoryItem ],
276- mem_cube : GeneralMemCube ,
277- ) -> None :
278- """
279- Update activation memory by extracting KVCacheItems from new_memory (list of str),
280- add them to a KVCacheMemory instance, and dump to disk.
281- """
282- if len (new_memories ) == 0 :
283- logger .error ("update_activation_memory: new_memory is empty." )
284- return
285- if isinstance (new_memories [0 ], TextualMemoryItem ):
286- new_text_memories = [mem .memory for mem in new_memories ]
287- elif isinstance (new_memories [0 ], str ):
288- new_text_memories = new_memories
289- else :
290- logger .error ("Not Implemented." )
291-
292- try :
293- assert isinstance (mem_cube .act_mem , KVCacheMemory )
294- act_mem : KVCacheMemory = mem_cube .act_mem
295-
296- text_memory = MEMORY_ASSEMBLY_TEMPLATE .format (
297- memory_text = "" .join (
298- [
299- f"{ i + 1 } . { sentence .strip ()} \n "
300- for i , sentence in enumerate (new_text_memories )
301- if sentence .strip () # Skip empty strings
302- ]
303- )
304- )
305- if self .act_mem_backend == ACTIVATION_MEMORY_HF_BACKEND :
306- # huggingface kv cache
307- original_cache_items : List [KVCacheItem ] = act_mem .get_all ()
308- pre_cache_item : KVCacheItem = origin_cache_items [- 1 ]
309- original_text_memories = pre_cache_item .records .text_memories
310- act_mem .delete_all ()
311- cache_item : KVCacheItem = act_mem .extract (text_memory )
312- cache_item .records .text_memories = new_text_memories
313-
314- act_mem .add (cache_item )
315- act_mem .dump (self .act_mem_dump_path )
316-
317- elif self .act_mem_backend == ACTIVATION_MEMORY_VLLM_BACKEND :
318- # vllm kv cache
319- self .log_activation_memory_update (original_text_memories = original_text_memories ,
320- new_text_memories = new_text_memories ,
321- user_id = user_id ,
322- mem_cube_id = mem_cube_id ,
323- mem_cube = mem_cube )
324- else :
325- raise NotImplementedError (self .act_mem_backend )
326-
327- except Exception as e :
328- logger .warning (f"MOS-based activation memory update failed: { e } " )
329287
330288 def create_autofilled_log_item (
331289 self ,
@@ -421,32 +379,72 @@ def _message_consumer(self) -> None:
421379
422380 def start (self ) -> None :
423381 """
424- Start the message consumer thread.
382+ Start the message consumer thread and initialize dispatcher resources .
425383
426- Initializes and starts a daemon thread that will periodically
427- check for and process messages from the queue.
384+ Initializes and starts:
385+ 1. Message consumer thread
386+ 2. Dispatcher thread pool (if parallel dispatch enabled)
428387 """
429- if self ._consumer_thread is not None and self . _consumer_thread . is_alive () :
430- logger .warning ("Memory Scheduler thread is already running" )
388+ if self ._running :
389+ logger .warning ("Memory Scheduler is already running" )
431390 return
432391
392+ # Initialize dispatcher resources
393+ if self .enable_parallel_dispatch :
394+ logger .info (f"Initializing dispatcher thread pool with { self .max_workers } workers" )
395+
396+ # Start consumer thread
433397 self ._running = True
434398 self ._consumer_thread = threading .Thread (
435399 target = self ._message_consumer ,
436- daemon = True , # Allows program to exit even if thread is running
400+ daemon = True ,
437401 name = "MessageConsumerThread" ,
438402 )
439403 self ._consumer_thread .start ()
440404 logger .info ("Message consumer thread started" )
441405
442406 def stop (self ) -> None :
443- """Stop the consumer thread and clean up resources."""
444- if self ._consumer_thread is None or not self ._running :
445- logger .warning ("Memory Scheduler thread is not running" )
407+ """Stop all scheduler components gracefully.
408+
409+ 1. Stops message consumer thread
410+ 2. Shuts down dispatcher thread pool
411+ 3. Cleans up resources
412+ """
413+ if not self ._running :
414+ logger .warning ("Memory Scheduler is not running" )
446415 return
416+
417+ # Signal consumer thread to stop
447418 self ._running = False
448- if self ._consumer_thread .is_alive ():
449- self ._consumer_thread .join (timeout = 5.0 ) # Wait up to 5 seconds
419+
420+ # Wait for consumer thread
421+ if self ._consumer_thread and self ._consumer_thread .is_alive ():
422+ self ._consumer_thread .join (timeout = 5.0 )
450423 if self ._consumer_thread .is_alive ():
451- logger .warning ("Memory Scheduler thread did not stop gracefully" )
452- logger .info ("Memory Scheduler thread stopped" )
424+ logger .warning ("Consumer thread did not stop gracefully" )
425+ else :
426+ logger .info ("Consumer thread stopped" )
427+
428+ # Shutdown dispatcher
429+ if hasattr (self , 'dispatcher' ) and self .dispatcher :
430+ logger .info ("Shutting down dispatcher..." )
431+ self .dispatcher .shutdown ()
432+
433+ # Clean up queues
434+ self ._cleanup_queues ()
435+ logger .info ("Memory Scheduler stopped completely" )
436+
437+ def _cleanup_queues (self ) -> None :
438+ """Ensure all queues are emptied and marked as closed."""
439+ try :
440+ while not self .memos_message_queue .empty ():
441+ self .memos_message_queue .get_nowait ()
442+ self .memos_message_queue .task_done ()
443+ except queue .Empty :
444+ pass
445+
446+ try :
447+ while not self ._web_log_message_queue .empty ():
448+ self ._web_log_message_queue .get_nowait ()
449+ except queue .Empty :
450+ pass
0 commit comments