Skip to content

Commit eb16685

Browse files
committed
refactor & fix bugs: memory scheduler code, including the dispatcher and retriever modules, their dependencies, and the unit tests.
1 parent 9323ede commit eb16685

File tree

11 files changed

+643
-344
lines changed

11 files changed

+643
-344
lines changed

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ sqlalchemy = "^2.0.41"
2828
redis = "^6.2.0"
2929
pika = "^1.3.2"
3030
schedule = "^1.2.2"
31+
pika = "^1.3.2"
3132

3233
[tool.poetry.group.dev]
3334
optional = false

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 69 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,13 @@ def __init__(self, config: BaseSchedulerConfig):
5353

5454
# hyper-parameters
5555
self.top_k = self.config.get("top_k", 5)
56-
self.act_mem_update_interval = self.config.get("act_mem_update_interval", 300)
5756
self.context_window_size = self.config.get("context_window_size", 5)
5857
self.activation_mem_size = self.config.get(
5958
"activation_mem_size", DEFAULT_ACTIVATION_MEM_SIZE
6059
)
6160
self.enable_act_memory_update = self.config.get("enable_act_memory_update", False)
6261
self.act_mem_dump_path = self.config.get("act_mem_dump_path", DEFAULT_ACT_MEM_DUMP_PATH)
6362
self.search_method = TreeTextMemory_SEARCH_METHOD
64-
self._last_activation_mem_update_time = 0.0
6563

6664
self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", False)
6765
self.max_workers = self.config.get(
@@ -143,6 +141,22 @@ def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None:
143141
self._current_mem_cube_id = msg.mem_cube_id
144142
self._current_mem_cube = msg.mem_cube
145143

144+
def _validate_messages(self, messages: list[ScheduleMessageItem], label: str):
145+
"""Validate if all messages match the expected label.
146+
147+
Args:
148+
messages: List of message items to validate.
149+
label: Expected message label (e.g., QUERY_LABEL/ANSWER_LABEL).
150+
151+
Returns:
152+
bool: True if all messages passed validation, False if any failed.
153+
"""
154+
for message in messages:
155+
if not self._validate_message(message, label):
156+
return False
157+
logger.error("Message batch contains invalid labels, aborting processing")
158+
return True
159+
146160
def _validate_message(self, message: ScheduleMessageItem, label: str):
147161
"""Validate if the message matches the expected label.
148162
@@ -270,62 +284,6 @@ def log_working_memory_replacement(
270284
logger.info(f"{len(added_memories)} {LONG_TERM_MEMORY_TYPE} memorie(s) "
271285
f"transformed to {WORKING_MEMORY_TYPE} memories.")
272286

273-
def update_activation_memory(
274-
self,
275-
new_memories: list[str | TextualMemoryItem],
276-
mem_cube: GeneralMemCube,
277-
) -> None:
278-
"""
279-
Update activation memory by extracting KVCacheItems from new_memory (list of str),
280-
add them to a KVCacheMemory instance, and dump to disk.
281-
"""
282-
if len(new_memories) == 0:
283-
logger.error("update_activation_memory: new_memory is empty.")
284-
return
285-
if isinstance(new_memories[0], TextualMemoryItem):
286-
new_text_memories = [mem.memory for mem in new_memories]
287-
elif isinstance(new_memories[0], str):
288-
new_text_memories = new_memories
289-
else:
290-
logger.error("Not Implemented.")
291-
292-
try:
293-
assert isinstance(mem_cube.act_mem, KVCacheMemory)
294-
act_mem: KVCacheMemory = mem_cube.act_mem
295-
296-
text_memory = MEMORY_ASSEMBLY_TEMPLATE.format(
297-
memory_text="".join(
298-
[
299-
f"{i + 1}. {sentence.strip()}\n"
300-
for i, sentence in enumerate(new_text_memories)
301-
if sentence.strip() # Skip empty strings
302-
]
303-
)
304-
)
305-
if self.act_mem_backend == ACTIVATION_MEMORY_HF_BACKEND :
306-
# huggingface kv cache
307-
original_cache_items: List[KVCacheItem] = act_mem.get_all()
308-
pre_cache_item: KVCacheItem = origin_cache_items[-1]
309-
original_text_memories = pre_cache_item.records.text_memories
310-
act_mem.delete_all()
311-
cache_item: KVCacheItem = act_mem.extract(text_memory)
312-
cache_item.records.text_memories = new_text_memories
313-
314-
act_mem.add(cache_item)
315-
act_mem.dump(self.act_mem_dump_path)
316-
317-
elif self.act_mem_backend == ACTIVATION_MEMORY_VLLM_BACKEND :
318-
# vllm kv cache
319-
self.log_activation_memory_update(original_text_memories=original_text_memories,
320-
new_text_memories=new_text_memories,
321-
user_id=user_id,
322-
mem_cube_id=mem_cube_id,
323-
mem_cube=mem_cube)
324-
else:
325-
raise NotImplementedError(self.act_mem_backend)
326-
327-
except Exception as e:
328-
logger.warning(f"MOS-based activation memory update failed: {e}")
329287

330288
def create_autofilled_log_item(
331289
self,
@@ -421,32 +379,72 @@ def _message_consumer(self) -> None:
421379

422380
def start(self) -> None:
    """
    Launch the background message consumer thread.

    Does nothing except log a warning when the scheduler is already flagged
    as running. Otherwise:
    1. When parallel dispatch is enabled, logs the configured worker count
       (no pool is created here).
    2. Sets the running flag, then creates and starts a daemon thread named
       "MessageConsumerThread" that executes ``self._message_consumer``.
    """
    already_running = self._running
    if already_running:
        logger.warning("Memory Scheduler is already running")
        return

    # Announce dispatcher configuration
    if self.enable_parallel_dispatch:
        logger.info(f"Initializing dispatcher thread pool with {self.max_workers} workers")

    # The flag is raised before the thread starts so the consumer sees it set.
    self._running = True
    consumer = threading.Thread(
        name="MessageConsumerThread",
        target=self._message_consumer,
        daemon=True,
    )
    self._consumer_thread = consumer
    consumer.start()
    logger.info("Message consumer thread started")
441405

442406
def stop(self) -> None:
    """Shut the scheduler down.

    Does nothing except log a warning when the scheduler is not flagged as
    running. Otherwise, in order:

    1. Clears the running flag.
    2. Joins the consumer thread (up to 5 seconds) if it is alive, and logs
       whether it actually terminated.
    3. Calls ``shutdown()`` on ``self.dispatcher`` when that attribute
       exists and is truthy.
    4. Drains the queues via ``self._cleanup_queues()``.
    """
    if not self._running:
        logger.warning("Memory Scheduler is not running")
        return

    # Clear the flag first, before waiting on the consumer thread.
    self._running = False

    # Give the consumer a bounded amount of time to finish.
    consumer = self._consumer_thread
    if consumer and consumer.is_alive():
        consumer.join(timeout=5.0)
        if not consumer.is_alive():
            logger.info("Consumer thread stopped")
        else:
            logger.warning("Consumer thread did not stop gracefully")

    # The dispatcher attribute may be absent, so look it up defensively.
    dispatcher = getattr(self, "dispatcher", None)
    if dispatcher:
        logger.info("Shutting down dispatcher...")
        dispatcher.shutdown()

    # Drop whatever is still queued.
    self._cleanup_queues()
    logger.info("Memory Scheduler stopped completely")
436+
437+
def _cleanup_queues(self) -> None:
438+
"""Ensure all queues are emptied and marked as closed."""
439+
try:
440+
while not self.memos_message_queue.empty():
441+
self.memos_message_queue.get_nowait()
442+
self.memos_message_queue.task_done()
443+
except queue.Empty:
444+
pass
445+
446+
try:
447+
while not self._web_log_message_queue.empty():
448+
self._web_log_message_queue.get_nowait()
449+
except queue.Empty:
450+
pass

0 commit comments

Comments
 (0)