
Commit 3e721da

feat(mem_scheduler): add messages logging for stuck tasks monitoring
- Add RunningTaskItem schema with optional messages field in task_schemas.py
- Update dispatcher to pass messages when creating RunningTaskItem instances
- Enhance dispatcher_monitor to log messages info for stuck tasks (count + first 3 messages)
- Add comprehensive unit tests for new messages functionality
- Fix existing test assertions to handle dispatcher's message grouping logic

This improvement provides better debugging visibility for stuck tasks by including the actual message content and count in monitoring logs.
1 parent 5481f56 commit 3e721da

File tree

6 files changed: +451 -50 lines changed


src/memos/mem_scheduler/base_scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def __init__(self, config: BaseSchedulerConfig):
         self.enable_activation_memory = self.config.get("enable_activation_memory", False)
         self.act_mem_dump_path = self.config.get("act_mem_dump_path", DEFAULT_ACT_MEM_DUMP_PATH)
         self.search_method = TreeTextMemory_SEARCH_METHOD
-        self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", False)
+        self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", True)
         self.thread_pool_max_workers = self.config.get(
             "thread_pool_max_workers", DEFAULT_THREAD_POOL_MAX_WORKERS
         )
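
For reference, the settings touched here are plain config keys read via config.get in BaseScheduler.__init__. A minimal sketch of overriding them follows; the key names come from the diff above, while the shape of the config object and how it is loaded are assumptions not shown in this commit:

    # Hypothetical config fragment (key names from the diff; structure assumed)
    scheduler_config = {
        "enable_parallel_dispatch": True,   # now defaults to True per this commit
        "thread_pool_max_workers": 30,      # DEFAULT_THREAD_POOL_MAX_WORKERS also raised to 30
    }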

src/memos/mem_scheduler/general_modules/dispatcher.py

Lines changed: 110 additions & 21 deletions
@@ -10,6 +10,7 @@
 from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
 from memos.mem_scheduler.general_modules.task_threads import ThreadManager
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem


 logger = get_logger(__name__)
@@ -28,7 +29,7 @@ class SchedulerDispatcher(BaseSchedulerModule):
     - Thread race competition for parallel task execution
     """

-    def __init__(self, max_workers=30, enable_parallel_dispatch=False, config=None):
+    def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None):
         super().__init__()
         self.config = config

@@ -58,6 +59,68 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=False, config=None):
         # Thread race module for competitive task execution
         self.thread_manager = ThreadManager(thread_pool_executor=self.dispatcher_executor)

+        # Task tracking for monitoring
+        self._running_tasks: dict[str, RunningTaskItem] = {}
+        self._task_lock = threading.Lock()
+
+    def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem):
+        """
+        Create a wrapper around the handler to track task execution and capture results.
+
+        Args:
+            handler: The original handler function
+            task_item: The RunningTaskItem to track
+
+        Returns:
+            Wrapped handler function that captures results and logs completion
+        """
+
+        def wrapped_handler(messages: list[ScheduleMessageItem]):
+            try:
+                # Execute the original handler
+                result = handler(messages)
+
+                # Mark task as completed and remove from tracking
+                with self._task_lock:
+                    if task_item.item_id in self._running_tasks:
+                        task_item.mark_completed(result)
+                        del self._running_tasks[task_item.item_id]
+
+                logger.info(f"Task completed: {task_item.get_execution_info()}")
+                return result
+
+            except Exception as e:
+                # Mark task as failed and remove from tracking
+                with self._task_lock:
+                    if task_item.item_id in self._running_tasks:
+                        task_item.mark_failed(str(e))
+                        del self._running_tasks[task_item.item_id]
+
+                logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}")
+                raise
+
+        return wrapped_handler
+
+    def get_running_tasks(self) -> dict[str, RunningTaskItem]:
+        """
+        Get a copy of currently running tasks.
+
+        Returns:
+            Dictionary of running tasks keyed by task ID
+        """
+        with self._task_lock:
+            return self._running_tasks.copy()
+
+    def get_running_task_count(self) -> int:
+        """
+        Get the count of currently running tasks.
+
+        Returns:
+            Number of running tasks
+        """
+        with self._task_lock:
+            return len(self._running_tasks)
+
     def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]):
         """
         Register a handler function for a specific message label.
@@ -126,7 +189,7 @@ def unregister_handlers(self, labels: list[str]) -> dict[str, bool]:
     def _default_message_handler(self, messages: list[ScheduleMessageItem]) -> None:
         logger.debug(f"Using _default_message_handler to deal with messages: {messages}")

-    def group_messages_by_user_and_cube(
+    def _group_messages_by_user_and_mem_cube(
         self, messages: list[ScheduleMessageItem]
     ) -> dict[str, dict[str, list[ScheduleMessageItem]]]:
         """
@@ -176,25 +239,51 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]):
             logger.debug("Received empty message list, skipping dispatch")
             return

-        # Group messages by their labels, and organize messages by label
-        label_groups = defaultdict(list)
-        for message in msg_list:
-            label_groups[message.label].append(message)
-
-        # Process each label group
-        for label, msgs in label_groups.items():
-            handler = self.handlers.get(label, self._default_message_handler)
-
-            # dispatch to different handler
-            logger.debug(f"Dispatch {len(msgs)} message(s) to {label} handler.")
-            if self.enable_parallel_dispatch and self.dispatcher_executor is not None:
-                # Capture variables in lambda to avoid loop variable issues
-                future = self.dispatcher_executor.submit(handler, msgs)
-                self._futures.add(future)
-                future.add_done_callback(self._handle_future_result)
-                logger.info(f"Dispatched {len(msgs)} message(s) as future task")
-            else:
-                handler(msgs)
+        # Group messages by user_id and mem_cube_id first
+        user_cube_groups = self._group_messages_by_user_and_mem_cube(msg_list)
+
+        # Process each user and mem_cube combination
+        for user_id, cube_groups in user_cube_groups.items():
+            for mem_cube_id, user_cube_msgs in cube_groups.items():
+                # Group messages by their labels within each user/mem_cube combination
+                label_groups = defaultdict(list)
+                for message in user_cube_msgs:
+                    label_groups[message.label].append(message)
+
+                # Process each label group within this user/mem_cube combination
+                for label, msgs in label_groups.items():
+                    handler = self.handlers.get(label, self._default_message_handler)
+
+                    # Create task tracking item for this dispatch
+                    task_item = RunningTaskItem(
+                        user_id=user_id,
+                        mem_cube_id=mem_cube_id,
+                        task_info=f"Processing {len(msgs)} message(s) with label '{label}' for user {user_id} and mem_cube {mem_cube_id}",
+                        task_name=f"{label}_handler",
+                        messages=msgs,
+                    )
+
+                    # Add to running tasks
+                    with self._task_lock:
+                        self._running_tasks[task_item.item_id] = task_item
+
+                    # Create wrapped handler for task tracking
+                    wrapped_handler = self._create_task_wrapper(handler, task_item)
+
+                    # dispatch to different handler
+                    logger.debug(
+                        f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}."
+                    )
+                    logger.info(f"Task started: {task_item.get_execution_info()}")
+
+                    if self.enable_parallel_dispatch and self.dispatcher_executor is not None:
+                        # Capture variables in lambda to avoid loop variable issues
+                        future = self.dispatcher_executor.submit(wrapped_handler, msgs)
+                        self._futures.add(future)
+                        future.add_done_callback(self._handle_future_result)
+                        logger.info(f"Dispatched {len(msgs)} message(s) as future task")
+                    else:
+                        wrapped_handler(msgs)

     def join(self, timeout: float | None = None) -> bool:
         """Wait for all dispatched tasks to complete.

src/memos/mem_scheduler/monitors/dispatcher_monitor.py

Lines changed: 81 additions & 6 deletions
@@ -9,6 +9,11 @@
 from memos.log import get_logger
 from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
 from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher
+from memos.mem_scheduler.schemas.general_schemas import (
+    DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL,
+    DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES,
+    DEFAULT_STUCK_THREAD_TOLERANCE,
+)


 logger = get_logger(__name__)
@@ -21,8 +26,12 @@ def __init__(self, config: BaseSchedulerConfig):
         super().__init__()
         self.config: BaseSchedulerConfig = config

-        self.check_interval = self.config.get("dispatcher_monitor_check_interval", 300)
-        self.max_failures = self.config.get("dispatcher_monitor_max_failures", 2)
+        self.check_interval = self.config.get(
+            "dispatcher_monitor_check_interval", DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL
+        )
+        self.max_failures = self.config.get(
+            "dispatcher_monitor_max_failures", DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES
+        )

         # Registry of monitored thread pools
         self._pools: dict[str, dict] = {}
@@ -189,22 +198,77 @@ def _check_pools_health(self) -> None:
            ):
                self._restart_pool(name, pool_info)

-    def _check_pool_health(self, pool_info: dict, stuck_max_interval=4) -> tuple[bool, str]:
+    def _check_pool_health(
+        self, pool_info: dict, stuck_max_interval=4, stuck_thread_tolerance=None
+    ) -> tuple[bool, str]:
         """
-        Check health of a single thread pool.
+        Check health of a single thread pool with enhanced task tracking.

         Args:
             pool_info: Dictionary containing pool configuration
+            stuck_max_interval: Maximum intervals before considering pool stuck
+            stuck_thread_tolerance: Maximum number of stuck threads to tolerate before restarting pool

         Returns:
             Tuple: (is_healthy, reason) where reason explains failure if not healthy
         """
+        if stuck_thread_tolerance is None:
+            stuck_thread_tolerance = DEFAULT_STUCK_THREAD_TOLERANCE
+
         executor = pool_info["executor"]

         # Check if executor is shutdown
         if executor._shutdown:  # pylint: disable=protected-access
             return False, "Executor is shutdown"

+        # Enhanced health check using dispatcher task tracking
+        stuck_tasks = []
+        if self.dispatcher:
+            running_tasks = self.dispatcher.get_running_tasks()
+            running_count = self.dispatcher.get_running_task_count()
+
+            # Log detailed task information
+            if running_tasks:
+                logger.debug(f"Currently running {running_count} tasks:")
+                for _task_id, task in running_tasks.items():
+                    logger.debug(f"  - {task.get_execution_info()}")
+            else:
+                logger.debug("No tasks currently running")
+
+            # Check for stuck tasks (running longer than expected)
+            for task in running_tasks.values():
+                if task.duration_seconds and task.duration_seconds > (
+                    self.check_interval * stuck_max_interval
+                ):
+                    stuck_tasks.append(task)
+
+            # Always log stuck tasks if any exist
+            if stuck_tasks:
+                logger.warning(f"Found {len(stuck_tasks)} potentially stuck tasks:")
+                for task in stuck_tasks:
+                    task_info = task.get_execution_info()
+                    messages_info = ""
+                    if task.messages:
+                        messages_info = f", Messages: {len(task.messages)} items - {[str(msg) for msg in task.messages[:3]]}"
+                        if len(task.messages) > 3:
+                            messages_info += f" ... and {len(task.messages) - 3} more"
+                    logger.warning(f"  - Stuck task: {task_info}{messages_info}")

+            # Check if stuck task count exceeds tolerance
+            # If thread pool size is smaller, use the smaller value as threshold
+            max_workers = pool_info.get("max_workers", 0)
+            effective_tolerance = (
+                min(stuck_thread_tolerance, max_workers)
+                if max_workers > 0
+                else stuck_thread_tolerance
+            )
+
+            if len(stuck_tasks) >= effective_tolerance:
+                return (
+                    False,
+                    f"Found {len(stuck_tasks)} stuck tasks (tolerance: {effective_tolerance})",
+                )
+
         # Check thread activity
         active_threads = sum(
             1
@@ -216,13 +280,24 @@ def _check_pool_health(self, pool_info: dict, stuck_max_interval=4) -> tuple[boo
         if active_threads == 0 and pool_info["max_workers"] > 0:
             return False, "No active worker threads"

-        # Check if threads are stuck (no activity for 2 intervals)
+        # Check if threads are stuck (no activity for specified intervals)
         time_delta = (datetime.utcnow() - pool_info["last_active"]).total_seconds()
         if time_delta >= self.check_interval * stuck_max_interval:
-            return False, "No recent activity"
+            return False, f"No recent activity for {time_delta:.1f} seconds"

         # If we got here, pool appears healthy
         pool_info["last_active"] = datetime.utcnow()
+
+        # Log health status with comprehensive information
+        if self.dispatcher:
+            task_count = self.dispatcher.get_running_task_count()
+            max_workers = pool_info.get("max_workers", 0)
+            stuck_count = len(stuck_tasks)
+            logger.info(
+                f"Pool health check passed - {active_threads} active threads, "
+                f"{task_count} running tasks, pool size: {max_workers}, stuck tasks: {stuck_count}"
+            )
+
         return True, ""

     def _restart_pool(self, name: str, pool_info: dict) -> None:
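
The stuck-task logic above boils down to two comparisons: a task counts as stuck once its duration exceeds check_interval * stuck_max_interval, and the pool is reported unhealthy once the stuck count reaches min(stuck_thread_tolerance, max_workers) (or the raw tolerance when the pool size is unknown). A small sketch of that arithmetic, detached from the monitor class and using the defaults introduced in this commit (300 s check interval, tolerance 10):

    def is_task_stuck(duration_seconds: float | None, check_interval: float, stuck_max_interval: int = 4) -> bool:
        """A task is treated as stuck once it has run longer than check_interval * stuck_max_interval."""
        return bool(duration_seconds) and duration_seconds > check_interval * stuck_max_interval


    def pool_exceeds_tolerance(stuck_count: int, stuck_thread_tolerance: int = 10, max_workers: int = 0) -> bool:
        """The pool is flagged for restart once stuck tasks reach the effective tolerance (capped by pool size)."""
        effective_tolerance = min(stuck_thread_tolerance, max_workers) if max_workers > 0 else stuck_thread_tolerance
        return stuck_count >= effective_tolerance


    print(is_task_stuck(1500.0, check_interval=300))   # True: 1500 s > 300 s * 4
    print(pool_exceeds_tolerance(3, max_workers=30))   # False: 3 < min(10, 30)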

src/memos/mem_scheduler/schemas/general_schemas.py

Lines changed: 5 additions & 1 deletion
@@ -17,8 +17,12 @@
 DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT = 30
 DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT = 20
 DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache"
-DEFAULT_THREAD_POOL_MAX_WORKERS = 10
+DEFAULT_THREAD_POOL_MAX_WORKERS = 30
 DEFAULT_CONSUME_INTERVAL_SECONDS = 0.05
+DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300
+DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2
+DEFAULT_STUCK_THREAD_TOLERANCE = 10
+
 NOT_INITIALIZED = -1

src/memos/mem_scheduler/schemas/task_schemas.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+from uuid import uuid4
+
+from pydantic import BaseModel, Field, computed_field
+
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.misc import DictConversionMixin
+
+
+logger = get_logger(__name__)
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent
+
+
+# ============== Running Tasks ==============
+class RunningTaskItem(BaseModel, DictConversionMixin):
+    """Data class for tracking running tasks in SchedulerDispatcher."""
+
+    item_id: str = Field(
+        description="Unique identifier for the task item", default_factory=lambda: str(uuid4())
+    )
+    user_id: str = Field(..., description="Required user identifier", min_length=1)
+    mem_cube_id: str = Field(..., description="Required memory cube identifier", min_length=1)
+    task_info: str = Field(..., description="Information about the task being executed")
+    task_name: str = Field(..., description="Name/type of the task handler")
+    start_time: datetime = Field(description="Task start time", default_factory=datetime.utcnow)
+    end_time: datetime | None = Field(default=None, description="Task completion time")
+    status: str = Field(default="running", description="Task status: running, completed, failed")
+    result: Any | None = Field(default=None, description="Task execution result")
+    error_message: str | None = Field(default=None, description="Error message if task failed")
+    messages: list[Any] | None = Field(
+        default=None, description="List of messages being processed by this task"
+    )
+
+    def mark_completed(self, result: Any | None = None) -> None:
+        """Mark task as completed with optional result."""
+        self.end_time = datetime.utcnow()
+        self.status = "completed"
+        self.result = result
+
+    def mark_failed(self, error_message: str) -> None:
+        """Mark task as failed with error message."""
+        self.end_time = datetime.utcnow()
+        self.status = "failed"
+        self.error_message = error_message
+
+    @computed_field
+    @property
+    def duration_seconds(self) -> float | None:
+        """Calculate task duration in seconds."""
+        if self.end_time:
+            return (self.end_time - self.start_time).total_seconds()
+        return None
+
+    def get_execution_info(self) -> str:
+        """Get formatted execution information for logging."""
+        duration = self.duration_seconds
+        duration_str = f"{duration:.2f}s" if duration else "ongoing"
+
+        return (
+            f"Task {self.task_name} (ID: {self.item_id[:8]}) "
+            f"for user {self.user_id}, cube {self.mem_cube_id} - "
+            f"Status: {self.status}, Duration: {duration_str}"
+        )
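
Since RunningTaskItem is a plain Pydantic model, its lifecycle can be exercised directly. A short usage sketch, assuming the memos package is importable (the import path matches the one added in dispatcher.py); the identifiers and label are illustrative values:

    from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem

    # Create a tracking item the way the dispatcher does when it starts a handler.
    task = RunningTaskItem(
        user_id="user_1",
        mem_cube_id="cube_1",
        task_info="Processing 2 message(s) with label 'example_label'",
        task_name="example_label_handler",
        messages=["msg_a", "msg_b"],  # optional; accepts any items
    )

    print(task.get_execution_info())   # Status: running, Duration: ongoing

    task.mark_completed(result={"ok": True})
    print(task.status)                 # "completed"
    print(task.duration_seconds)       # seconds between start_time and end_time
    print(task.get_execution_info())   # Status: completed, with the computed duration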
