Skip to content

Commit 173bebc

Browse files
committed
fix bugs: address ai review issues
1 parent 28e1368 commit 173bebc

File tree

6 files changed

+44
-34
lines changed

6 files changed

+44
-34
lines changed

evaluation/scripts/locomo/locomo_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ async def main(frame, version="default", options=None, num_runs=1, max_workers=4
311311
with open(response_path) as file:
312312
locomo_responses = json.load(file)
313313

314-
num_users = 2
314+
num_users = 10
315315
all_grades = {}
316316

317317
total_responses_count = sum(

src/memos/mem_scheduler/monitors/task_schedule_monitor.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,12 @@ async def _collect_async() -> dict:
209209
return local
210210

211211
try:
212-
loop = asyncio.get_running_loop()
213-
if loop.is_running():
214-
raise RuntimeError("event loop running")
212+
asyncio.get_running_loop()
213+
loop_running = True
215214
except RuntimeError:
216-
loop = None
215+
loop_running = False
217216

218-
if loop is None:
217+
if not loop_running:
219218
return asyncio.run(_collect_async())
220219
except Exception as e:
221220
logger.debug(f"Parallel status collection failed, fallback to sequential: {e}")

src/memos/mem_scheduler/task_schedule_modules/dispatcher.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,6 @@ def __init__(
9393
# Task tracking for monitoring
9494
self._running_tasks: dict[str, RunningTaskItem] = {}
9595
self._task_lock = threading.Lock()
96-
self._completed_tasks = []
97-
self.completed_tasks_max_show_size = 10
9896

9997
# Configure shutdown wait behavior from config or default
10098
self.stop_wait = (
@@ -128,6 +126,10 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
128126
self.status_tracker.task_started(
129127
task_id=task_item.item_id, user_id=task_item.user_id
130128
)
129+
# Record task as running for monitoring (LocalQueue only)
130+
if isinstance(self.memos_message_queue, SchedulerLocalQueue):
131+
with self._task_lock:
132+
self._running_tasks[task_item.item_id] = task_item
131133
try:
132134
# --- mark start: record queuing time(now - enqueue_ts)---
133135
now = time.time()
@@ -179,14 +181,12 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
179181
redis_message_id=redis_message_id,
180182
)
181183

182-
# Mark task as completed and remove from tracking
183-
with self._task_lock:
184-
if task_item.item_id in self._running_tasks:
185-
task_item.mark_completed(result)
186-
del self._running_tasks[task_item.item_id]
187-
self._completed_tasks.append(task_item)
188-
if len(self._completed_tasks) > self.completed_tasks_max_show_size:
189-
self._completed_tasks.pop(0)
184+
# Mark task as completed and remove from tracking (LocalQueue only)
185+
if isinstance(self.memos_message_queue, SchedulerLocalQueue):
186+
with self._task_lock:
187+
if task_item.item_id in self._running_tasks:
188+
task_item.mark_completed(result)
189+
del self._running_tasks[task_item.item_id]
190190
logger.info(f"Task completed: {task_item.get_execution_info()}")
191191
return result
192192

@@ -197,13 +197,12 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
197197
self.status_tracker.task_failed(
198198
task_id=task_item.item_id, user_id=task_item.user_id, error_message=str(e)
199199
)
200-
# Mark task as failed and remove from tracking
201-
with self._task_lock:
202-
if task_item.item_id in self._running_tasks:
203-
task_item.mark_failed(str(e))
204-
del self._running_tasks[task_item.item_id]
205-
if len(self._completed_tasks) > self.completed_tasks_max_show_size:
206-
self._completed_tasks.pop(0)
200+
# Mark task as failed and remove from tracking (LocalQueue only)
201+
if isinstance(self.memos_message_queue, SchedulerLocalQueue):
202+
with self._task_lock:
203+
if task_item.item_id in self._running_tasks:
204+
task_item.mark_failed(str(e))
205+
del self._running_tasks[task_item.item_id]
207206
logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}")
208207

209208
raise
@@ -238,10 +237,20 @@ def get_running_tasks(
238237
lambda task: task.user_id == "user123" and task.status == "running"
239238
)
240239
"""
241-
with self._task_lock:
240+
# Use lock only for LocalQueue; otherwise read without lock
241+
if isinstance(self.memos_message_queue, SchedulerLocalQueue):
242+
with self._task_lock:
243+
if filter_func is None:
244+
return self._running_tasks.copy()
245+
246+
return {
247+
task_id: task_item
248+
for task_id, task_item in self._running_tasks.items()
249+
if filter_func(task_item)
250+
}
251+
else:
242252
if filter_func is None:
243253
return self._running_tasks.copy()
244-
245254
return {
246255
task_id: task_item
247256
for task_id, task_item in self._running_tasks.items()
@@ -255,7 +264,11 @@ def get_running_task_count(self) -> int:
255264
Returns:
256265
Number of running tasks
257266
"""
258-
with self._task_lock:
267+
# Use lock only for LocalQueue; otherwise read without lock
268+
if isinstance(self.memos_message_queue, SchedulerLocalQueue):
269+
with self._task_lock:
270+
return len(self._running_tasks)
271+
else:
259272
return len(self._running_tasks)
260273

261274
def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]):
@@ -393,10 +406,6 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]):
393406
messages=msgs,
394407
)
395408

396-
# Track running task for status/monitoring
397-
with self._task_lock:
398-
self._running_tasks[task_item.item_id] = task_item
399-
400409
# Create wrapped handler for task tracking
401410
wrapped_handler = self._create_task_wrapper(handler, task_item)
402411

src/memos/mem_scheduler/task_schedule_modules/local_queue.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def put(
6161
queue.Full: If the queue is full and block=False or timeout expires.
6262
Exception: Any underlying error during queue.put() operation.
6363
"""
64-
stream_key = self.get_stream_key(user_id=message.user_id, mem_cube_id=message.mem_cube_id)
64+
stream_key = self.get_stream_key(
65+
user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.task_label
66+
)
6567

6668
message.stream_key = stream_key
6769

src/memos/multi_mem_cube/single_cube.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,8 @@ def _search_pref(
375375
"""
376376
if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
377377
return []
378-
print(f"search_req.filter for preference memory: {search_req.filter}")
379-
print(f"type of pref_mem: {type(self.naive_mem_cube.pref_mem)}")
378+
logger.info(f"search_req.filter for preference memory: {search_req.filter}")
379+
logger.info(f"type of pref_mem: {type(self.naive_mem_cube.pref_mem)}")
380380
try:
381381
results = self.naive_mem_cube.pref_mem.search(
382382
query=search_req.query,

src/memos/vec_dbs/milvus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def search(
229229
List of search results with distance scores and payloads.
230230
"""
231231
# Convert filter to Milvus expression
232-
print(f"filter for milvus: {filter}")
232+
logger.info(f"filter for milvus: {filter}")
233233
expr = self._dict_to_expr(filter) if filter else ""
234234

235235
search_func_map = {

0 commit comments

Comments (0)