-
Notifications
You must be signed in to change notification settings - Fork 677
Expand file tree
/
Copy pathdispatcher.py
More file actions
639 lines (543 loc) · 24.5 KB
/
dispatcher.py
File metadata and controls
639 lines (543 loc) · 24.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
import concurrent
import concurrent.futures
import threading
import time
from collections import defaultdict
from collections.abc import Callable
from datetime import datetime, timezone
from typing import Any

from memos.context.context import (
    ContextThreadPoolExecutor,
    RequestContext,
    generate_trace_id,
    set_request_context,
)
from memos.log import get_logger
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.general_modules.task_threads import ThreadManager
from memos.mem_scheduler.schemas.general_schemas import (
    DEFAULT_STOP_WAIT,
)
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem
from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue
from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
logger = get_logger(__name__)
class SchedulerDispatcher(BaseSchedulerModule):
"""
Thread pool-based message dispatcher that routes messages to dedicated handlers
based on their labels.
Features:
- Dedicated thread pool per message label
- Batch message processing
- Graceful shutdown
- Bulk handler registration
- Thread race competition for parallel task execution
"""
def __init__(
    self,
    max_workers: int = 30,
    memos_message_queue: ScheduleTaskQueue | None = None,
    enable_parallel_dispatch: bool = True,
    config=None,
    status_tracker: TaskStatusTracker | None = None,
    metrics: Any | None = None,
    submit_web_logs: Callable | None = None,
    orchestrator: SchedulerOrchestrator | None = None,
):
    """Initialize the dispatcher.

    Args:
        max_workers: Size of the dispatcher thread pool (parallel mode only).
        memos_message_queue: Either a ScheduleTaskQueue wrapper or a concrete
            queue instance; the wrapper is unwrapped to its inner queue.
        enable_parallel_dispatch: When True, handlers run on a thread pool;
            otherwise they run synchronously on the caller's thread.
        config: Optional mapping-like config (read via ``.get``).
        status_tracker: Optional TaskStatusTracker for per-task lifecycle updates.
        metrics: Optional metrics sink; may be None (callers must guard).
        submit_web_logs: Optional callable used to forward web log entries.
        orchestrator: Optional SchedulerOrchestrator; a default one is created
            when not supplied.
    """
    super().__init__()
    self.config = config

    # Main dispatcher thread pool size
    self.max_workers = max_workers

    # Accept either a ScheduleTaskQueue wrapper or a concrete queue instance
    self.memos_message_queue = (
        memos_message_queue.memos_message_queue
        if hasattr(memos_message_queue, "memos_message_queue")
        else memos_message_queue
    )

    self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator

    # Timeout used by run_multiple_tasks when the caller passes none
    self.multi_task_running_timeout = (
        self.config.get("multi_task_running_timeout") if self.config else None
    )

    # Only initialize the thread pool when running in parallel mode
    self.enable_parallel_dispatch = enable_parallel_dispatch
    self.thread_name_prefix = "dispatcher"
    if self.enable_parallel_dispatch:
        self.dispatcher_executor = ContextThreadPoolExecutor(
            max_workers=self.max_workers, thread_name_prefix=self.thread_name_prefix
        )
        logger.info(f"Max workers of dispatcher is set to {self.max_workers}")
    else:
        self.dispatcher_executor = None
    logger.info(f"enable_parallel_dispatch is set to {self.enable_parallel_dispatch}")

    # Registered message handlers, keyed by message label
    self.handlers: dict[str, Callable] = {}
    # Dispatcher running state (toggled by __enter__/shutdown)
    self._running = False
    # Active futures, tracked for join()/stats(); guarded by _task_lock
    self._futures = set()
    # Thread race module for competitive task execution
    self.thread_manager = ThreadManager(thread_pool_executor=self.dispatcher_executor)
    # Running-task registry for monitoring; guarded by _task_lock
    self._running_tasks: dict[str, RunningTaskItem] = {}
    self._task_lock = threading.Lock()

    # Whether shutdown() blocks until workers finish (config override or default)
    self.stop_wait = (
        self.config.get("stop_wait", DEFAULT_STOP_WAIT) if self.config else DEFAULT_STOP_WAIT
    )
    self.metrics = metrics
    self.status_tracker = status_tracker
    self.submit_web_logs = submit_web_logs
def on_messages_enqueued(self, msgs: list[ScheduleMessageItem]) -> None:
    """Hook invoked after messages are placed on the queue.

    Enqueue-time bookkeeping now lives in BaseScheduler, so this is a no-op
    regardless of whether *msgs* is empty.
    """
    if msgs:
        # Intentionally nothing to do -- BaseScheduler handles this now.
        return
def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem):
    """
    Create a wrapper around the handler to track task execution and capture results.

    The wrapper, around each handler invocation:
    - propagates trace_id/user info into the logging context,
    - records queue-wait and execution-duration metrics (only when a metrics
      sink is configured -- ``self.metrics`` defaults to None),
    - emits "start"/"finish" monitor events,
    - updates the status tracker and the running-task registry, and
    - acknowledges Redis messages in ``finally`` so failures still ack.

    Args:
        handler: The original handler function
        task_item: The RunningTaskItem to track

    Returns:
        Wrapped handler function that captures results and logs completion
    """

    def wrapped_handler(messages: list[ScheduleMessageItem]):
        start_time = time.time()
        start_iso = datetime.fromtimestamp(start_time, tz=timezone.utc).isoformat()
        if self.status_tracker:
            for msg in messages:
                self.status_tracker.task_started(task_id=msg.item_id, user_id=msg.user_id)
        try:
            first_msg = messages[0]
            trace_id = getattr(first_msg, "trace_id", None) or generate_trace_id()
            # Propagate trace_id and user info to logging context for this handler execution
            ctx = RequestContext(
                trace_id=trace_id,
                user_name=getattr(first_msg, "user_name", None),
                user_type=None,
            )
            set_request_context(ctx)

            # --- mark start: record queuing time (now - enqueue_ts) ---
            now = time.time()
            m = first_msg  # All messages in this batch have same user and type
            enq_ts = getattr(first_msg, "timestamp", None)
            # Path 1: epoch seconds (preferred)
            if isinstance(enq_ts, int | float):
                enq_epoch = float(enq_ts)
            # Path 2: datetime -> normalize to UTC epoch
            elif hasattr(enq_ts, "timestamp"):
                dt = enq_ts
                if dt.tzinfo is None:
                    # treat naive as UTC to neutralize +8h skew
                    dt = dt.replace(tzinfo=timezone.utc)
                enq_epoch = dt.timestamp()
            else:
                # fallback: treat as "just now"
                enq_epoch = now
            wait_sec = max(0.0, now - enq_epoch)
            # Metrics sink is optional (None when not configured) -- guard every call.
            if self.metrics is not None:
                self.metrics.observe_task_wait_duration(wait_sec, m.user_id, m.label)

            dequeue_ts = getattr(first_msg, "_dequeue_ts", None)
            start_delay_ms = None
            if isinstance(dequeue_ts, int | float):
                start_delay_ms = max(0.0, start_time - dequeue_ts) * 1000
            emit_monitor_event(
                "start",
                first_msg,
                {
                    "start_ts": start_iso,
                    "start_delay_ms": start_delay_ms,
                    "enqueue_ts": to_iso(enq_ts),
                    "dequeue_ts": to_iso(
                        datetime.fromtimestamp(dequeue_ts, tz=timezone.utc)
                        if isinstance(dequeue_ts, int | float)
                        else None
                    ),
                },
            )

            # Execute the original handler
            result = handler(messages)

            # --- mark done ---
            finish_time = time.time()
            duration = finish_time - start_time
            if self.metrics is not None:
                self.metrics.observe_task_duration(duration, m.user_id, m.label)
            if self.status_tracker:
                for msg in messages:
                    self.status_tracker.task_completed(task_id=msg.item_id, user_id=msg.user_id)
            if self.metrics is not None:
                self.metrics.task_completed(user_id=m.user_id, task_type=m.label)
            emit_monitor_event(
                "finish",
                first_msg,
                {
                    "status": "ok",
                    "start_ts": start_iso,
                    "finish_ts": datetime.fromtimestamp(
                        finish_time, tz=timezone.utc
                    ).isoformat(),
                    "exec_duration_ms": duration * 1000,
                    "total_duration_ms": self._calc_total_duration_ms(
                        finish_time, getattr(first_msg, "timestamp", None)
                    ),
                },
            )

            # Redis ack is handled in finally to cover failure cases
            # Mark task as completed and remove from tracking
            with self._task_lock:
                if task_item.item_id in self._running_tasks:
                    task_item.mark_completed(result)
                    del self._running_tasks[task_item.item_id]
            logger.info(f"Task completed: {task_item.get_execution_info()}")
            return result
        except Exception as e:
            m = messages[0]
            finish_time = time.time()
            if self.metrics is not None:
                self.metrics.task_failed(m.user_id, m.label, type(e).__name__)
            if self.status_tracker:
                for msg in messages:
                    self.status_tracker.task_failed(
                        task_id=msg.item_id, user_id=msg.user_id, error_message=str(e)
                    )
            emit_monitor_event(
                "finish",
                m,
                {
                    "status": "fail",
                    "start_ts": start_iso,
                    "finish_ts": datetime.fromtimestamp(
                        finish_time, tz=timezone.utc
                    ).isoformat(),
                    "exec_duration_ms": (finish_time - start_time) * 1000,
                    "error_type": type(e).__name__,
                    "error_msg": str(e),
                    "total_duration_ms": self._calc_total_duration_ms(
                        finish_time, getattr(m, "timestamp", None)
                    ),
                },
            )
            # Mark task as failed and remove from tracking
            with self._task_lock:
                if task_item.item_id in self._running_tasks:
                    task_item.mark_failed(str(e))
                    del self._running_tasks[task_item.item_id]
            logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}")
            raise
        finally:
            # Ensure Redis messages are acknowledged even if the handler fails.
            # isinstance() is False for None, so no separate None check is needed.
            if isinstance(self.memos_message_queue, SchedulerRedisQueue):
                try:
                    for msg in messages:
                        redis_message_id = msg.redis_message_id
                        self.memos_message_queue.ack_message(
                            user_id=msg.user_id,
                            mem_cube_id=msg.mem_cube_id,
                            task_label=msg.label,
                            redis_message_id=redis_message_id,
                        )
                except Exception as ack_err:
                    logger.warning(f"Ack in finally failed: {ack_err}")

    return wrapped_handler
def get_running_tasks(
    self, filter_func: Callable[[RunningTaskItem], bool] | None = None
) -> dict[str, RunningTaskItem]:
    """Return a snapshot of the currently running tasks.

    Args:
        filter_func: Optional predicate over a RunningTaskItem; only tasks
            for which it returns True are included. When omitted, all
            running tasks are returned.

    Returns:
        Dictionary of running tasks keyed by task ID (a copy -- safe to
        mutate without affecting the dispatcher's internal registry).

    Examples:
        all_tasks = dispatcher.get_running_tasks()
        user_tasks = dispatcher.get_running_tasks(lambda task: task.user_id == "user123")
        handler_tasks = dispatcher.get_running_tasks(lambda task: task.task_name == "test_handler")
    """
    with self._task_lock:
        if filter_func is None:
            return dict(self._running_tasks)
        snapshot = {}
        for task_id, task_item in self._running_tasks.items():
            if filter_func(task_item):
                snapshot[task_id] = task_item
        return snapshot
def get_running_task_count(self) -> int:
"""
Get the count of currently running tasks.
Returns:
Number of running tasks
"""
with self._task_lock:
return len(self._running_tasks)
def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]):
    """Bind *handler* as the processor for messages tagged with *label*.

    Re-registering an existing label silently replaces the old handler.
    """
    self.handlers[label] = handler
def register_handlers(
    self, handlers: dict[str, Callable[[list[ScheduleMessageItem]], None]]
) -> None:
    """Register many label -> handler bindings at once.

    Entries with a non-string label or a non-callable handler are skipped
    (logged as errors); every valid entry is registered individually.
    """
    for label, handler in handlers.items():
        if not isinstance(label, str):
            logger.error(f"Invalid label type: {type(label)}. Expected str.")
        elif not callable(handler):
            logger.error(f"Handler for label '{label}' is not callable.")
        else:
            self.register_handler(label=label, handler=handler)
    logger.info(f"Registered {len(handlers)} handlers in bulk")
def unregister_handler(self, label: str) -> bool:
    """Remove the handler bound to *label*.

    Returns:
        bool: True when a handler existed and was removed, False otherwise.
    """
    if label not in self.handlers:
        logger.warning(f"No handler found for label: {label}")
        return False
    del self.handlers[label]
    logger.info(f"Unregistered handler for label: {label}")
    return True
def unregister_handlers(self, labels: list[str]) -> dict[str, bool]:
    """Remove the handlers for several labels.

    Args:
        labels: Labels whose handlers should be removed.

    Returns:
        dict[str, bool]: Per-label success flag from unregister_handler.
    """
    results = {label: self.unregister_handler(label) for label in labels}
    logger.info(f"Unregistered handlers for {len(labels)} labels")
    return results
def stats(self) -> dict[str, int]:
"""
Lightweight runtime stats for monitoring.
Returns:
{
'running': <number of running tasks>,
'inflight': <number of futures tracked (pending+running)>,
'handlers': <registered handler count>,
}
"""
try:
running = self.get_running_task_count()
except Exception:
running = 0
try:
with self._task_lock:
inflight = len(self._futures)
except Exception:
inflight = 0
try:
handlers = len(self.handlers)
except Exception:
handlers = 0
return {"running": running, "inflight": inflight, "handlers": handlers}
def _default_message_handler(self, messages: list[ScheduleMessageItem]) -> None:
    """Fallback used by dispatch() when no handler is registered for a label."""
    logger.debug(f"Using _default_message_handler to deal with messages: {messages}")
def _handle_future_result(self, future):
with self._task_lock:
self._futures.discard(future)
try:
future.result() # this will throw exception
except Exception as e:
logger.error(f"Handler execution failed: {e!s}", exc_info=True)
@staticmethod
def _calc_total_duration_ms(finish_epoch: float, enqueue_ts) -> float | None:
"""
Calculate total duration from enqueue timestamp to finish time in milliseconds.
"""
try:
enq_epoch = None
if isinstance(enqueue_ts, int | float):
enq_epoch = float(enqueue_ts)
elif hasattr(enqueue_ts, "timestamp"):
dt = enqueue_ts
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
enq_epoch = dt.timestamp()
if enq_epoch is None:
return None
total_ms = max(0.0, finish_epoch - enq_epoch) * 1000
return total_ms
except Exception:
return None
def execute_task(
    self,
    user_id: str,
    mem_cube_id: str,
    task_label: str,
    msgs: list[ScheduleMessageItem],
    handler_call_back: Callable[[list[ScheduleMessageItem]], Any],
):
    """Track and run one batch of messages through *handler_call_back*.

    Registers a RunningTaskItem before execution, then either submits the
    wrapped handler to the thread pool (parallel mode) or runs it inline
    on the calling thread. A single ScheduleMessageItem is accepted and
    normalized to a one-element list.
    """
    if isinstance(msgs, ScheduleMessageItem):
        msgs = [msgs]

    # Create task tracking item for this dispatch
    task_item = RunningTaskItem(
        user_id=user_id,
        mem_cube_id=mem_cube_id,
        task_info=f"Processing {len(msgs)} message(s) with label '{task_label}' for user {user_id} and mem_cube {mem_cube_id}",
        task_name=f"{task_label}_handler",
        messages=msgs,
    )
    # Register the task uniformly before execution
    with self._task_lock:
        self._running_tasks[task_item.item_id] = task_item

    # Wrap the handler so completion/failure bookkeeping always happens
    wrapped_handler = self._create_task_wrapper(handler_call_back, task_item)
    logger.debug(f"Task started: {task_item.get_execution_info()}")

    if self.enable_parallel_dispatch and self.dispatcher_executor is not None:
        # Submit asynchronously and keep the future for join()/stats()
        future = self.dispatcher_executor.submit(wrapped_handler, msgs)
        with self._task_lock:
            self._futures.add(future)
        future.add_done_callback(self._handle_future_result)
        logger.info(
            f"Dispatch {len(msgs)} message(s) to {task_label} handler for user {user_id} and mem_cube {mem_cube_id}."
        )
    else:
        # Synchronous path: the wrapper removes the task when it finishes
        logger.info(
            f"Execute {len(msgs)} message(s) synchronously for {task_label} for user {user_id} and mem_cube {mem_cube_id}."
        )
        wrapped_handler(msgs)
def dispatch(self, msg_list: list[ScheduleMessageItem]):
    """
    Dispatch a list of messages to their respective handlers.

    Messages are batched first by (user_id, mem_cube_id), then by label,
    and each batch is handed to execute_task with the handler registered
    for that label (or the default no-op handler).

    Args:
        msg_list: List of ScheduleMessageItem objects to process
    """
    if not msg_list:
        logger.debug("Received empty message list, skipping dispatch")
        return

    grouped = group_messages_by_user_and_mem_cube(msg_list)
    for user_id, cube_groups in grouped.items():
        for mem_cube_id, cube_msgs in cube_groups.items():
            # Partition this user/cube batch by message label
            by_label = defaultdict(list)
            for msg in cube_msgs:
                by_label[msg.label].append(msg)
            for label, batch in by_label.items():
                self.execute_task(
                    user_id=user_id,
                    mem_cube_id=mem_cube_id,
                    task_label=label,
                    msgs=batch,
                    handler_call_back=self.handlers.get(label, self._default_message_handler),
                )
def join(self, timeout: float | None = None) -> bool:
"""Wait for all dispatched tasks to complete.
Args:
timeout: Maximum time to wait in seconds. None means wait forever.
Returns:
bool: True if all tasks completed, False if timeout occurred.
"""
if not self.enable_parallel_dispatch or self.dispatcher_executor is None:
return True # Serial mode requires no waiting
done, not_done = concurrent.futures.wait(
self._futures, timeout=timeout, return_when=concurrent.futures.ALL_COMPLETED
)
# Check for exceptions in completed tasks
for future in done:
try:
future.result()
except Exception:
logger.error("Handler failed during shutdown", exc_info=True)
return len(not_done) == 0
def run_competitive_tasks(
    self, tasks: dict[str, Callable[[threading.Event], Any]], timeout: float = 10.0
) -> tuple[str, Any] | None:
    """Race several tasks; the first to finish wins.

    Args:
        tasks: Mapping of task name -> callable accepting a stop_flag Event.
        timeout: Seconds to wait for any task to complete.

    Returns:
        (task_name, result) for the winning task, or None if none completed.
    """
    logger.info(f"Starting competitive execution of {len(tasks)} tasks")
    return self.thread_manager.run_race(tasks, timeout)
def run_multiple_tasks(
    self,
    tasks: dict[str, tuple[Callable, tuple]],
    use_thread_pool: bool | None = None,
    timeout: float | None = None,
) -> dict[str, Any]:
    """
    Execute multiple tasks concurrently and return all results.

    Args:
        tasks: Mapping of task name -> (callable, args-tuple).
        use_thread_pool: Whether to use a thread pool; defaults to this
            dispatcher's parallel-dispatch setting.
        timeout: Seconds to wait for all tasks; defaults to the configured
            multi-task timeout.

    Returns:
        Dictionary mapping task names to their results.

    Raises:
        TimeoutError: If tasks don't complete within the specified timeout.
    """
    # Fall back to dispatcher defaults when the caller leaves args unset
    effective_pool = (
        self.enable_parallel_dispatch if use_thread_pool is None else use_thread_pool
    )
    effective_timeout = self.multi_task_running_timeout if timeout is None else timeout

    logger.info(
        f"Executing {len(tasks)} tasks concurrently (thread_pool: {effective_pool}, timeout: {effective_timeout})"
    )
    try:
        results = self.thread_manager.run_multiple_tasks(
            tasks=tasks, use_thread_pool=effective_pool, timeout=effective_timeout
        )
    except Exception as e:
        logger.error(f"Multiple tasks execution failed: {e}", exc_info=True)
        raise

    completed = len([r for r in results.values() if r is not None])
    logger.info(f"Successfully completed {completed}/{len(tasks)} tasks")
    return results
def shutdown(self) -> None:
"""Gracefully shutdown the dispatcher."""
self._running = False
# Shutdown executor
try:
self.dispatcher_executor.shutdown(wait=self.stop_wait, cancel_futures=True)
except Exception as e:
logger.error(f"Executor shutdown error: {e}", exc_info=True)
finally:
self._futures.clear()
def __enter__(self):
self._running = True
return self
def __exit__(self, exc_type, exc_val, exc_tb):
    """Leave the context manager: always shut the dispatcher down."""
    self.shutdown()