
Commit 7f39e7e

address the incompatibility issue of the local scheduler
1 parent a6881b4 commit 7f39e7e

5 files changed (+164 lines, -34 lines)

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 15 additions & 5 deletions
@@ -1009,14 +1009,24 @@ def _monitor_loop(self):
                 q_sizes = self.memos_message_queue.qsize()

                 for stream_key, queue_length in q_sizes.items():
-                    # Expected format: "memos:stream:{user_id}:{mem_cube_id}" or "{user_id}"
+                    # Skip aggregate keys like 'total_size'
+                    if stream_key == "total_size":
+                        continue
+
+                    # Key format: ...:{user_id}:{mem_cube_id}:{task_label}
+                    # We want to extract user_id, which is the third component from the end.
                     parts = stream_key.split(":")
                     if len(parts) >= 3:
-                        user_id = parts[2]
-                        self.metrics.update_queue_length(queue_length, user_id)
-                    elif not self.use_redis_queue:  # local queue
-                        user_id = stream_key
+                        user_id = parts[-3]
                         self.metrics.update_queue_length(queue_length, user_id)
+                    else:
+                        # Fallback for unexpected key formats (e.g. legacy or test keys):
+                        # if the key contains no colons, treat it as a direct user_id
+                        # (an alternative would be to log a warning here).
+                        # This should not happen with the current queue implementations,
+                        # which always produce colon-separated stream keys.
+                        if ":" not in stream_key:
+                            self.metrics.update_queue_length(queue_length, stream_key)

             except Exception as e:
                 logger.error(f"Error in metrics monitor loop: {e}", exc_info=True)
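
The key parsing above can be illustrated with a small standalone sketch (not from the repo). The "memos:stream" prefix below is only an assumed example; the invariant the patch relies on is that the last three components of a stream key are user_id, mem_cube_id and task_label.

# Minimal sketch of the stream-key parsing used in _monitor_loop.
def extract_user_id(stream_key: str) -> str | None:
    if stream_key == "total_size":  # aggregate entry reported by qsize(), not a stream
        return None
    parts = stream_key.split(":")
    if len(parts) >= 3:
        return parts[-3]  # third component from the end is the user_id
    # fallback: a bare key without colons is treated as a direct user_id
    return stream_key if ":" not in stream_key else None


assert extract_user_id("memos:stream:alice:cube_1:label_a") == "alice"
assert extract_user_id("total_size") is None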

src/memos/mem_scheduler/task_schedule_modules/local_queue.py

Lines changed: 69 additions & 6 deletions
@@ -62,7 +62,7 @@ def put(
             Exception: Any underlying error during queue.put() operation.
         """
         stream_key = self.get_stream_key(
-            user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.task_label
+            user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label
         )

         message.stream_key = stream_key
@@ -108,35 +108,95 @@ def get(
         )
         return res

-    def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]:
+    def get_nowait(
+        self, stream_key: str, batch_size: int | None = None
+    ) -> list[ScheduleMessageItem]:
         """
-        Non-blocking version of get(). Equivalent to get(block=False, batch_size=batch_size).
+        Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size).

         Returns immediately with available messages or an empty list if queue is empty.

         Args:
+            stream_key (str): The stream/queue identifier.
             batch_size (int | None): Number of messages to retrieve in a batch.
                 If None, retrieves one message.

         Returns:
             List[ScheduleMessageItem]: Retrieved messages or empty list if queue is empty.
         """
-        logger.debug(f"get_nowait() called with batch_size: {batch_size}")
-        return self.get(block=False, batch_size=batch_size)
+        logger.debug(f"get_nowait() called for {stream_key} with batch_size: {batch_size}")
+        return self.get(stream_key=stream_key, block=False, batch_size=batch_size)
+
+    def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
+        """
+        Get messages from all streams in round-robin or sequential fashion.
+        Equivalent to SchedulerRedisQueue.get_messages.
+        """
+        messages = []
+        # Snapshot keys to avoid runtime modification issues
+        stream_keys = list(self.queue_streams.keys())
+
+        # Simple strategy: try to collect up to batch_size messages across all streams.
+        # We just iterate over the streams and take what each one has.
+
+        # A fair per-stream allocation (as in the Redis implementation) would be possible,
+        # but the local queue keeps it simple: scan the streams in order and take
+        # whatever is available without blocking.
+
+        for stream_key in stream_keys:
+            if len(messages) >= batch_size:
+                break
+
+            needed = batch_size - len(messages)
+            # Use get_nowait to avoid blocking
+            fetched = self.get_nowait(stream_key=stream_key, batch_size=needed)
+            messages.extend(fetched)
+
+        return messages

     def qsize(self) -> dict:
         """
         Return the current size of all internal queues as a dictionary.

         Each key is the stream name, and each value is the number of messages in that queue.
+        Also includes 'total_size'.

         Returns:
             Dict[str, int]: Mapping from stream name to current queue size.
         """
         sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()}
+        total_size = sum(sizes.values())
+        sizes["total_size"] = total_size
         logger.debug(f"Current queue sizes: {sizes}")
         return sizes

+    def size(self) -> int:
+        """
+        Get the current size of the queue (total message count).
+        Compatible with SchedulerRedisQueue.
+        """
+        return self.unfinished_tasks
+
+    def empty(self) -> bool:
+        """
+        Check if the queue is empty.
+        Compatible with SchedulerRedisQueue.
+        """
+        return self.size() == 0
+
+    def full(self) -> bool:
+        """
+        Check if the queue is full.
+        Compatible with SchedulerRedisQueue.
+
+        Returns True if all internal queues are full.
+        If there are no queues, returns False.
+        """
+        if not self.queue_streams:
+            return False
+
+        return all(queue.full() for queue in self.queue_streams.values())
+
     def clear(self) -> None:
         for queue in self.queue_streams.values():
             queue.clear()
@@ -151,6 +211,9 @@ def unfinished_tasks(self) -> int:
         Returns:
             int: Sum of all message counts in all internal queues.
         """
-        total = sum(self.qsize().values())
+        # qsize() now includes a "total_size" entry, so summing its values would double count.
+        # Sum the underlying queues directly instead, which keeps this property
+        # independent of the aggregate key.
+        total = sum(queue.qsize() for queue in self.queue_streams.values())
         logger.debug(f"Total unfinished tasks across all queues: {total}")
         return total
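
A short usage sketch (not part of the commit) of the Redis-compatible surface SchedulerLocalQueue now exposes. The message field values are placeholders, and the asserts assume the behavior described in the docstrings above.

from datetime import datetime, timezone

from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue

local_queue = SchedulerLocalQueue(maxsize=8)
local_queue.put(
    ScheduleMessageItem(
        user_id="u1",
        mem_cube_id="c1",
        label="add",
        content="hello",
        timestamp=datetime.now(timezone.utc),
    )
)

sizes = local_queue.qsize()  # per-stream sizes plus the "total_size" aggregate
assert sizes["total_size"] == local_queue.size() == 1
assert not local_queue.empty() and not local_queue.full()

# get_messages() scans every stream without blocking; here it returns the one item
msgs = local_queue.get_messages(batch_size=10)
assert len(msgs) == 1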

src/memos/mem_scheduler/task_schedule_modules/task_queue.py

Lines changed: 1 addition & 22 deletions
@@ -153,28 +153,7 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]):
         )

     def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
-        if isinstance(self.memos_message_queue, SchedulerRedisQueue):
-            return self.memos_message_queue.get_messages(batch_size=batch_size)
-        stream_keys = self.get_stream_keys()
-
-        if len(stream_keys) == 0:
-            return []
-
-        messages: list[ScheduleMessageItem] = []
-
-        for stream_key in stream_keys:
-            fetched = self.memos_message_queue.get(
-                stream_key=stream_key,
-                block=False,
-                batch_size=batch_size,
-            )
-
-            messages.extend(fetched)
-        if len(messages) > 0:
-            logger.debug(
-                f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}"
-            )
-        return messages
+        return self.memos_message_queue.get_messages(batch_size=batch_size)

     def clear(self):
         self.memos_message_queue.clear()
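
With this change, get_messages simply delegates to whichever queue backend is configured. Purely as an illustration, the duck-typed surface it now relies on looks roughly like the Protocol below; this Protocol is not defined in the codebase, both SchedulerRedisQueue and SchedulerLocalQueue are just expected to provide these methods.

from typing import Protocol

from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem


class MessageQueueLike(Protocol):
    # Hypothetical interface sketch; only get_messages/qsize/clear are shown.
    def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: ...

    def qsize(self) -> dict: ...

    def clear(self) -> None: ...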

src/memos/mem_scheduler/utils/status_tracker.py

Lines changed: 25 additions & 1 deletion
@@ -13,7 +13,7 @@

 class TaskStatusTracker:
     @require_python_package(import_name="redis", install_command="pip install redis")
-    def __init__(self, redis_client: "redis.Redis"):
+    def __init__(self, redis_client: "redis.Redis | None"):
         self.redis = redis_client

     def _get_key(self, user_id: str) -> str:
@@ -41,6 +41,9 @@ def task_submitted(
             mem_cube_id: Memory cube identifier
             business_task_id: Optional business-level task ID (one task_id can have multiple item_ids)
         """
+        if not self.redis:
+            return
+
         key = self._get_key(user_id)
         payload = {
             "status": "waiting",
@@ -61,6 +64,9 @@ def task_submitted(
         self.redis.expire(key, timedelta(days=7))

     def task_started(self, task_id: str, user_id: str):
+        if not self.redis:
+            return
+
         key = self._get_key(user_id)
         existing_data_json = self.redis.hget(key, task_id)
         if not existing_data_json:
@@ -77,6 +83,9 @@ def task_started(self, task_id: str, user_id: str):
         self.redis.expire(key, timedelta(days=7))

     def task_completed(self, task_id: str, user_id: str):
+        if not self.redis:
+            return
+
         key = self._get_key(user_id)
         existing_data_json = self.redis.hget(key, task_id)
         if not existing_data_json:
@@ -91,6 +100,9 @@ def task_completed(self, task_id: str, user_id: str):
         self.redis.expire(key, timedelta(days=7))

     def task_failed(self, task_id: str, user_id: str, error_message: str):
+        if not self.redis:
+            return
+
         key = self._get_key(user_id)
         existing_data_json = self.redis.hget(key, task_id)
         if not existing_data_json:
@@ -108,11 +120,17 @@ def task_failed(self, task_id: str, user_id: str, error_message: str):
         self.redis.expire(key, timedelta(days=7))

     def get_task_status(self, task_id: str, user_id: str) -> dict | None:
+        if not self.redis:
+            return None
+
         key = self._get_key(user_id)
         data = self.redis.hget(key, task_id)
         return json.loads(data) if data else None

     def get_all_tasks_for_user(self, user_id: str) -> dict[str, dict]:
+        if not self.redis:
+            return {}
+
         key = self._get_key(user_id)
         all_tasks = self.redis.hgetall(key)
         return {tid: json.loads(t_data) for tid, t_data in all_tasks.items()}
@@ -132,6 +150,9 @@ def get_task_status_by_business_id(self, business_task_id: str, user_id: str) ->
         - If any item is 'failed' → 'failed'
         Returns None if task_id not found.
         """
+        if not self.redis:
+            return None
+
         # Get all item_ids for this task_id
         task_items_key = self._get_task_items_key(user_id, business_task_id)
         item_ids = self.redis.smembers(task_items_key)
@@ -180,6 +201,9 @@ def get_all_tasks_global(self) -> dict[str, dict[str, dict]]:
         Returns:
             dict: {user_id: {task_id: task_data, ...}, ...}
         """
+        if not self.redis:
+            return {}
+
         all_users_tasks = {}
         cursor: int | str = 0
         while True:
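
A minimal sketch of the degraded behavior these guards enable. The redis package must still be importable because __init__ keeps the require_python_package decorator, but no Redis client or server is needed: writes become no-ops and reads return empty defaults.

from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker

tracker = TaskStatusTracker(redis_client=None)

tracker.task_started(task_id="t1", user_id="u1")    # returns immediately, nothing stored
tracker.task_completed(task_id="t1", user_id="u1")  # no-op as well

assert tracker.get_task_status("t1", "u1") is None
assert tracker.get_all_tasks_for_user("u1") == {}
assert tracker.get_all_tasks_global() == {}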

tests/test_local_queue_full.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+import unittest
+
+from datetime import datetime, timezone
+
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
+
+
+class TestLocalQueueFull(unittest.TestCase):
+    def test_full_behavior(self):
+        # Create a queue with very small maxsize for testing
+        lq = SchedulerLocalQueue(maxsize=1)
+
+        # Initially empty
+        self.assertFalse(lq.full())
+
+        # Add message to stream 1
+        msg1 = ScheduleMessageItem(
+            user_id="u1",
+            mem_cube_id="c1",
+            label="l1",
+            content="m1",
+            timestamp=datetime.now(timezone.utc),
+        )
+        lq.put(msg1)
+
+        # Now stream 1 is full (maxsize=1).
+        # Since it's the only stream, and it's full, lq.full() should be True.
+        self.assertTrue(lq.full())
+
+        # Add message to stream 2
+        msg2 = ScheduleMessageItem(
+            user_id="u2",
+            mem_cube_id="c2",
+            label="l2",
+            content="m2",
+            timestamp=datetime.now(timezone.utc),
+        )
+        lq.put(msg2)
+
+        # Now both stream 1 and stream 2 are full. lq.full() should be True.
+        self.assertTrue(lq.full())
+
+        # Remove message from stream 1
+        stream1_key = lq.get_stream_key("u1", "c1", "l1")
+        lq.get(stream1_key)
+
+        # Now stream 1 is empty, stream 2 is full.
+        # "all streams are full" is False.
+        self.assertFalse(lq.full())
+
+
+if __name__ == "__main__":
+    unittest.main()
