
Commit f0af4e1

hotfix: redis dependency in scheduler (#781)
## Description

Summary: (summary)
Fix: #(issue)
Docs Issue/PR: (docs-issue-or-pr-link)
Reviewer: @(reviewer)

## Checklist:

- [ ] I have performed a self-review of my own code
- [ ] I have commented my code in hard-to-understand areas
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have created related documentation issue/PR in [MemOS-Docs](https://github.com/MemTensor/MemOS-Docs) (if applicable)
- [ ] I have linked the issue to this PR (if applicable)
- [ ] I have mentioned the person who will review this PR
2 parents 33fc280 + 0990db6 commit f0af4e1

File tree: 5 files changed, +188 −79 lines changed

src/memos/mem_scheduler/task_schedule_modules/dispatcher.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -129,8 +129,6 @@ def status_tracker(self) -> TaskStatusTracker | None:
         try:
             self._status_tracker = TaskStatusTracker(self.redis)
             # Propagate to submodules when created lazily
-            if self.dispatcher:
-                self.dispatcher.status_tracker = self._status_tracker
             if self.memos_message_queue:
                 self.memos_message_queue.set_status_tracker(self._status_tracker)
         except Exception as e:
```

src/memos/mem_scheduler/task_schedule_modules/local_queue.py

Lines changed: 109 additions & 18 deletions
```diff
@@ -4,9 +4,18 @@
 the local memos_message_queue functionality in BaseScheduler.
 """
 
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 from memos.log import get_logger
 from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import DEFAULT_STREAM_KEY_PREFIX
+from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
+from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
 from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
 
 
@@ -16,26 +25,38 @@
 class SchedulerLocalQueue(RedisSchedulerModule):
     def __init__(
         self,
-        maxsize: int,
+        maxsize: int = 0,
+        stream_key_prefix: str = DEFAULT_STREAM_KEY_PREFIX,
+        orchestrator: SchedulerOrchestrator | None = None,
+        status_tracker: TaskStatusTracker | None = None,
     ):
         """
         Initialize the SchedulerLocalQueue with a maximum queue size limit.
+        Arguments match SchedulerRedisQueue for compatibility.
 
         Args:
-            maxsize (int): Maximum number of messages allowed
-                in each individual queue.
-                If exceeded, subsequent puts will block
-                or raise an exception based on `block` parameter.
+            maxsize (int): Maximum number of messages allowed in each individual queue.
+            stream_key_prefix (str): Prefix for stream keys (simulated).
+            orchestrator: SchedulerOrchestrator instance (ignored).
+            status_tracker: TaskStatusTracker instance (ignored).
         """
         super().__init__()
 
-        self.stream_key_prefix = "local_queue"
+        self.stream_key_prefix = stream_key_prefix or "local_queue"
 
         self.max_internal_message_queue_size = maxsize
+
         # Dictionary to hold per-stream queues: key = stream_key, value = Queue[ScheduleMessageItem]
         self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {}
+
+        self.orchestrator = orchestrator
+        self.status_tracker = status_tracker
+
+        self._is_listening = False
+        self._message_handler: Callable[[ScheduleMessageItem], None] | None = None
+
         logger.info(
-            f"SchedulerLocalQueue initialized with max_internal_message_queue_size={maxsize}"
+            f"SchedulerLocalQueue initialized with max_internal_message_queue_size={self.max_internal_message_queue_size}"
         )
 
     def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str:
@@ -62,7 +83,7 @@ def put(
             Exception: Any underlying error during queue.put() operation.
         """
         stream_key = self.get_stream_key(
-            user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.task_label
+            user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label
         )
 
         message.stream_key = stream_key
@@ -86,7 +107,7 @@ def get(
         stream_key: str,
         block: bool = True,
         timeout: float | None = None,
-        batch_size: int | None = None,
+        batch_size: int | None = 1,
     ) -> list[ScheduleMessageItem]:
         if batch_size is not None and batch_size <= 0:
             logger.warning(
@@ -99,47 +120,85 @@ def get(
             logger.error(f"Stream {stream_key} does not exist when trying to get messages.")
             return []
 
+        # Ensure we always request a batch so we get a list back
+        effective_batch_size = batch_size if batch_size is not None else 1
+
         # Note: Assumes custom Queue implementation supports batch_size parameter
         res = self.queue_streams[stream_key].get(
-            block=block, timeout=timeout, batch_size=batch_size
+            block=block, timeout=timeout, batch_size=effective_batch_size
         )
         logger.debug(
             f"Retrieved {len(res)} messages from queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}"
         )
         return res
 
-    def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]:
+    def get_nowait(self, stream_key: str, batch_size: int | None = 1) -> list[ScheduleMessageItem]:
         """
-        Non-blocking version of get(). Equivalent to get(block=False, batch_size=batch_size).
+        Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size).
 
         Returns immediately with available messages or an empty list if queue is empty.
 
         Args:
+            stream_key (str): The stream/queue identifier.
             batch_size (int | None): Number of messages to retrieve in a batch.
                 If None, retrieves one message.
 
         Returns:
             List[ScheduleMessageItem]: Retrieved messages or empty list if queue is empty.
         """
-        logger.debug(f"get_nowait() called with batch_size: {batch_size}")
-        return self.get(block=False, batch_size=batch_size)
+        logger.debug(f"get_nowait() called for {stream_key} with batch_size: {batch_size}")
+        return self.get(stream_key=stream_key, block=False, batch_size=batch_size)
+
+    def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
+        """
+        Get messages from all streams in round-robin or sequential fashion.
+        Equivalent to SchedulerRedisQueue.get_messages.
+        """
+        messages = []
+        # Snapshot keys to avoid runtime modification issues
+        stream_keys = list(self.queue_streams.keys())
+
+        # Simple strategy: try to get up to batch_size messages across all streams
+        # We can just iterate and collect.
+
+        # Calculate how many to get per stream to be fair?
+        # Or just greedy? Redis implementation uses a complex logic.
+        # For local, let's keep it simple: just iterate and take what's available (non-blocking)
+
+        for stream_key in stream_keys:
+            if len(messages) >= batch_size:
+                break
+
+            needed = batch_size - len(messages)
+            # Use get_nowait to avoid blocking
+            fetched = self.get_nowait(stream_key=stream_key, batch_size=needed)
+            messages.extend(fetched)
+
+        return messages
 
     def qsize(self) -> dict:
         """
         Return the current size of all internal queues as a dictionary.
 
         Each key is the stream name, and each value is the number of messages in that queue.
+        Also includes 'total_size'.
 
         Returns:
             Dict[str, int]: Mapping from stream name to current queue size.
         """
         sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()}
+        total_size = sum(sizes.values())
+        sizes["total_size"] = total_size
         logger.debug(f"Current queue sizes: {sizes}")
         return sizes
 
-    def clear(self) -> None:
-        for queue in self.queue_streams.values():
-            queue.clear()
+    def clear(self, stream_key: str | None = None) -> None:
+        if stream_key:
+            if stream_key in self.queue_streams:
+                self.queue_streams[stream_key].clear()
+        else:
+            for queue in self.queue_streams.values():
+                queue.clear()
 
     @property
     def unfinished_tasks(self) -> int:
@@ -151,6 +210,38 @@ def unfinished_tasks(self) -> int:
         Returns:
             int: Sum of all message counts in all internal queues.
         """
-        total = sum(self.qsize().values())
+        # qsize() now includes "total_size", so we need to be careful not to double count if we use qsize() values
+        # But qsize() implementation above sums values from queue_streams, then adds total_size.
+        # So sum(self.queue_streams.values().qsize()) is safer.
+        total = sum(queue.qsize() for queue in self.queue_streams.values())
         logger.debug(f"Total unfinished tasks across all queues: {total}")
         return total
+
+    def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]:
+        """
+        Return list of active stream keys.
+        """
+        prefix = stream_key_prefix or self.stream_key_prefix
+        return [k for k in self.queue_streams if k.startswith(prefix)]
+
+    def size(self) -> int:
+        """
+        Total size of all queues.
+        """
+        return sum(q.qsize() for q in self.queue_streams.values())
+
+    def empty(self) -> bool:
+        """
+        Check if all queues are empty.
+        """
+        return self.size() == 0
+
+    def full(self) -> bool:
+        """
+        Check if any queue is full (approximate).
+        """
+        if self.max_internal_message_queue_size <= 0:
+            return False
+        return any(
+            q.qsize() >= self.max_internal_message_queue_size for q in self.queue_streams.values()
+        )
```
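
Taken together, these additions make SchedulerLocalQueue a drop-in stand-in for SchedulerRedisQueue, which is what removes the hard Redis dependency from the scheduler. Below is a minimal usage sketch of that shared surface; the ScheduleMessageItem constructor fields and the exact put() signature are assumptions inferred from this diff, not verified against the full source:

```python
# Hypothetical sketch: exercising the Redis-compatible surface that
# SchedulerLocalQueue now exposes. The message fields and put() signature
# are assumptions inferred from the diff above.
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue

queue = SchedulerLocalQueue(maxsize=0)  # same defaulted signature as SchedulerRedisQueue

msg = ScheduleMessageItem(user_id="u1", mem_cube_id="cube1", label="add")
queue.put(msg)  # routed to a per-(user_id, mem_cube_id, label) stream

# Drain up to 10 messages across all active streams without blocking --
# the same call the scheduler makes against the Redis-backed queue.
batch = queue.get_messages(batch_size=10)

print(queue.qsize())                 # per-stream sizes plus "total_size"
print(queue.empty(), queue.full())   # parity helpers added in this commit
```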

src/memos/mem_scheduler/task_schedule_modules/redis_queue.py

Lines changed: 59 additions & 36 deletions
```diff
@@ -5,7 +5,6 @@
 the local memos_message_queue functionality in BaseScheduler.
 """
 
-import contextlib
 import os
 import re
 import threading
@@ -201,6 +200,20 @@ def _refresh_stream_keys(
             recent_seconds=DEFAULT_STREAM_RECENT_ACTIVE_SECONDS,
             now_sec=now_sec,
         )
+
+        # Ensure consumer groups for newly discovered active streams
+        with self._stream_keys_lock:
+            # Identify keys we haven't seen yet
+            new_streams = [k for k in active_stream_keys if k not in self.seen_streams]
+
+        # Create groups outside the lock to avoid blocking
+        for key in new_streams:
+            self._ensure_consumer_group(key)
+
+        if new_streams:
+            with self._stream_keys_lock:
+                self.seen_streams.update(new_streams)
+
         deleted_count = self._delete_streams(keys_to_delete)
         self._update_stream_cache_with_log(
             stream_key_prefix=stream_key_prefix,
```
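
The refresh path now bootstraps consumer groups once per newly discovered stream, instead of re-ensuring them on every read or claim. The helper it calls, _ensure_consumer_group, is not shown in this diff; the sketch below is the standard redis-py idiom such a helper typically wraps, not the repository's actual implementation:

```python
# Illustrative only: the usual XGROUP CREATE bootstrap that a helper like
# _ensure_consumer_group wraps. Not the repository's implementation.
import redis
from redis.exceptions import ResponseError

def ensure_consumer_group(conn: redis.Redis, stream_key: str, group: str) -> None:
    try:
        # mkstream=True creates the stream if it does not exist yet, so the
        # group can be registered before the first XADD touches the stream.
        conn.xgroup_create(name=stream_key, groupname=group, id="0", mkstream=True)
    except ResponseError as e:
        # BUSYGROUP means the group already exists; that is the common case
        # on repeated refreshes and is safe to ignore.
        if "BUSYGROUP" not in str(e):
            raise
```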
```diff
@@ -560,10 +573,7 @@ def _read_new_messages_batch(
             return {}
 
         # Pre-ensure consumer groups to avoid NOGROUP during batch reads
-        for stream_key in stream_keys:
-            with contextlib.suppress(Exception):
-                self._ensure_consumer_group(stream_key=stream_key)
-
+        # (Optimization: rely on put() and _refresh_stream_keys() to ensure groups)
         pipe = self._redis_conn.pipeline(transaction=False)
         for stream_key in stream_keys:
             pipe.xreadgroup(
@@ -676,13 +686,6 @@ def _batch_claim_pending_messages(
         Returns:
             A list of (stream_key, claimed_entries) pairs for all successful claims.
         """
-        if not self._redis_conn or not claims_spec:
-            return []
-
-        # Ensure consumer groups exist to avoid NOGROUP errors during batch claim
-        for stream_key, _need_count, _label in claims_spec:
-            with contextlib.suppress(Exception):
-                self._ensure_consumer_group(stream_key=stream_key)
 
         pipe = self._redis_conn.pipeline(transaction=False)
         for stream_key, need_count, label in claims_spec:
```
```diff
@@ -696,26 +699,42 @@ def _batch_claim_pending_messages(
                 justid=False,
             )
 
-        results = []
         try:
-            results = pipe.execute()
-        except Exception:
-            # Fallback: attempt sequential xautoclaim for robustness
-            for stream_key, need_count, label in claims_spec:
-                try:
-                    self._ensure_consumer_group(stream_key=stream_key)
-                    res = self._redis_conn.xautoclaim(
-                        name=stream_key,
-                        groupname=self.consumer_group,
-                        consumername=self.consumer_name,
-                        min_idle_time=self.orchestrator.get_task_idle_min(task_label=label),
-                        start_id="0-0",
-                        count=need_count,
-                        justid=False,
-                    )
-                    results.append(res)
-                except Exception:
-                    continue
+            # Execute with raise_on_error=False so we get exceptions in the results list
+            # instead of aborting the whole batch.
+            results = pipe.execute(raise_on_error=False)
+        except Exception as e:
+            logger.error(f"Pipeline execution critical failure: {e}")
+            results = [e] * len(claims_spec)
+
+        # Handle individual failures (e.g. NOGROUP) by retrying just that stream
+        final_results = []
+        for i, res in enumerate(results):
+            if isinstance(res, Exception):
+                err_msg = str(res).lower()
+                if "nogroup" in err_msg or "no such key" in err_msg:
+                    stream_key, need_count, label = claims_spec[i]
+                    try:
+                        self._ensure_consumer_group(stream_key=stream_key)
+                        retry_res = self._redis_conn.xautoclaim(
+                            name=stream_key,
+                            groupname=self.consumer_group,
+                            consumername=self.consumer_name,
+                            min_idle_time=self.orchestrator.get_task_idle_min(task_label=label),
+                            start_id="0-0",
+                            count=need_count,
+                            justid=False,
+                        )
+                        final_results.append(retry_res)
+                    except Exception as retry_err:
+                        logger.warning(f"Retry xautoclaim failed for {stream_key}: {retry_err}")
+                        final_results.append(None)
+                else:
+                    final_results.append(None)
+            else:
+                final_results.append(res)
+
+        results = final_results
 
         claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = []
         for (stream_key, _need_count, _label), claimed_result in zip(
```
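
The rewritten error handling leans on a redis-py pipeline detail: with raise_on_error=False, execute() returns each failed command's exception in place in the results list rather than raising on the first failure, so a single NOGROUP no longer discards the whole batch. A standalone sketch of that behavior follows; it assumes a reachable local Redis, and the stream and group names are made up for the demo:

```python
# Standalone illustration of execute(raise_on_error=False): failed commands
# come back as Exception objects in the results list, in command order,
# instead of aborting the batch. Names below are made up for the demo.
import redis

conn = redis.Redis()  # assumes a local Redis instance
pipe = conn.pipeline(transaction=False)
pipe.xadd("demo:stream", {"payload": "1"})  # should succeed
pipe.xreadgroup("missing-group", "c1", {"demo:stream": ">"}, count=1)  # NOGROUP

results = pipe.execute(raise_on_error=False)
for i, res in enumerate(results):
    if isinstance(res, Exception):
        print(f"command {i} failed: {res}")  # retry just this one, as in the hunk above
    else:
        print(f"command {i} ok: {res}")
```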
```diff
@@ -1159,17 +1178,23 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int:
                 del_pipe.delete(key)
             del_pipe.execute()
             deleted_count = len(keys_to_delete)
-            # Clean up empty-tracking state for deleted keys
+            # Clean up empty-tracking state and seen_streams for deleted keys
             with self._empty_stream_seen_lock:
                 for key in keys_to_delete:
                     self._empty_stream_seen_times.pop(key, None)
+
+            with self._stream_keys_lock:
+                for key in keys_to_delete:
+                    self.seen_streams.discard(key)
         except Exception:
             for key in keys_to_delete:
                 try:
                     self._redis_conn.delete(key)
                     deleted_count += 1
                     with self._empty_stream_seen_lock:
                         self._empty_stream_seen_times.pop(key, None)
+                    with self._stream_keys_lock:
+                        self.seen_streams.discard(key)
                 except Exception:
                     pass
         return deleted_count
@@ -1190,8 +1215,6 @@ def _update_stream_cache_with_log(
         self._stream_keys_last_refresh = time.time()
         cache_count = len(self._stream_keys_cache)
         logger.info(
-            f"[REDIS_QUEUE] Stream keys refresh: prefix='{stream_key_prefix}', "
-            f"total={len(candidate_keys)}, active={len(active_stream_keys)}, cached={cache_count}, "
-            f"active_threshold_sec={int(active_threshold_sec)}, deleted={deleted_count}, "
-            f"inactive_threshold_sec={int(DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS)}"
+            f"Refreshed stream keys cache: {cache_count} active keys, "
+            f"{deleted_count} deleted, {len(candidate_keys)} candidates examined."
         )
```
