Skip to content

Commit f5e032c

Browse files
author
yuan.wang
committed
merge dev
2 parents 571770b + cbcf33b commit f5e032c

File tree

12 files changed

+498
-121
lines changed

12 files changed

+498
-121
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import time
2+
3+
from memos.api.routers.server_router import mem_scheduler
4+
from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
5+
6+
7+
# Module-level handle to the queue object nested inside the scheduler's
# `memos_message_queue` wrapper; `main()` polls it and reads its
# `_stream_keys_cache` attribute.
# NOTE(review): presumably a SchedulerRedisQueue (that is what `fetch_status`
# is annotated to accept) — confirm against the scheduler configuration.
queue = mem_scheduler.memos_message_queue.memos_message_queue
8+
9+
10+
def fetch_status(queue: SchedulerRedisQueue) -> dict[str, dict[str, int]]:
    """Return the per-user status snapshot reported by the Redis queue.

    This is a thin wrapper around ``queue.show_task_status()``. Per the
    original note on this function, that call also prints a summary and the
    per-user counts itself, so callers only need to print deltas.

    Returns:
        A dict mapping user_id -> {"pending": int, "remaining": int}.
    """
    snapshot = queue.show_task_status()
    return snapshot
17+
18+
19+
def print_diff(prev: dict[str, dict[str, int]], curr: dict[str, dict[str, int]]) -> None:
    """Print aggregated totals and per-user changes compared to the previous snapshot.

    Args:
        prev: Previous snapshot, mapping user_id -> {"pending": int, "remaining": int}.
            An empty dict or ``None`` (e.g. on the first poll) counts as all zeros.
        curr: Current snapshot in the same shape. ``None`` is treated as empty.

    Prints one timestamped line with the current totals and their signed change,
    followed by one line per user whose pending or remaining count changed.
    Missing "pending"/"remaining" keys count as 0.
    """
    # Normalize missing snapshots up front. The totals used to be guarded with
    # `if prev else 0`, but the per-user section still called `prev.keys()`,
    # so a None snapshot crashed half-way through the report.
    prev = prev or {}
    curr = curr or {}

    def _total(snapshot: dict[str, dict[str, int]], field: str) -> int:
        # Sum one counter across every user in a snapshot.
        return sum(counts.get(field, 0) for counts in snapshot.values())

    ts = time.strftime("%Y-%m-%d %H:%M:%S")
    tot_p_curr = _total(curr, "pending")
    tot_r_curr = _total(curr, "remaining")
    dp_tot = tot_p_curr - _total(prev, "pending")
    dr_tot = tot_r_curr - _total(prev, "remaining")

    print(f"[{ts}] Total pending={tot_p_curr} ({dp_tot:+d}), remaining={tot_r_curr} ({dr_tot:+d})")

    # Print per-user deltas (current counts are already printed by show_task_status).
    # Dict key views support set union directly, so no intermediate sets are needed.
    for uid in sorted(prev.keys() | curr.keys()):
        before = prev.get(uid, {})
        after = curr.get(uid, {})
        dp = after.get("pending", 0) - before.get("pending", 0)
        dr = after.get("remaining", 0) - before.get("remaining", 0)
        # Only print when there is any change to reduce noise
        if dp != 0 or dr != 0:
            print(f" Δ {uid}: pending={dp:+d}, remaining={dr:+d}")
44+
45+
46+
# Note: queue.show_task_status() handles printing per-user counts internally.
47+
48+
49+
def main(interval_sec: float = 5.0) -> None:
    """Poll the module-level ``queue`` until interrupted, printing status deltas.

    Each cycle fetches the current per-user status, prints the change against
    the previous snapshot, prints the queue's ``_stream_keys_cache``, and then
    sleeps for ``interval_sec`` seconds. Errors raised during a cycle are
    reported and polling continues; the previous snapshot is kept in that case.

    Args:
        interval_sec: Seconds to wait between polls (also used after an error).
    """
    prev: dict[str, dict[str, int]] = {}
    # A single outer handler covers Ctrl-C at any point in the loop. Previously
    # the back-off sleep ran inside the `except Exception` handler, where a
    # KeyboardInterrupt escaped and produced a traceback instead of "Stopped.".
    try:
        while True:
            try:
                curr = fetch_status(queue)
                print_diff(prev, curr)
                print(f"stream_cache ({len(queue._stream_keys_cache)}): {queue._stream_keys_cache}")
                prev = curr
            except Exception as e:
                # Top-level boundary of a monitoring script: report and keep polling.
                print(f"Error while fetching status: {e}")
            time.sleep(interval_sec)
    except KeyboardInterrupt:
        print("Stopped.")


if __name__ == "__main__":
    main()

src/memos/llms/openai.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@ def __init__(self, config: OpenAILLMConfig):
2828
)
2929
logger.info("OpenAI LLM instance initialized")
3030

31-
@timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
31+
@timed_with_status(
32+
log_prefix="OpenAI LLM",
33+
log_extra_args=lambda self, messages, **kwargs: {
34+
"model_name_or_path": kwargs.get("model_name_or_path", self.config.model_name_or_path)
35+
},
36+
)
3237
def generate(self, messages: MessageList, **kwargs) -> str:
3338
"""Generate a response from OpenAI LLM, optionally overriding generation params."""
3439
response = self.client.chat.completions.create(
@@ -55,7 +60,12 @@ def generate(self, messages: MessageList, **kwargs) -> str:
5560
return reasoning_content + response_content
5661
return response_content
5762

58-
@timed_with_status(log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
63+
@timed_with_status(
64+
log_prefix="OpenAI LLM",
65+
log_extra_args=lambda self, messages, **kwargs: {
66+
"model_name_or_path": self.config.model_name_or_path
67+
},
68+
)
5969
def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
6070
"""Stream response from OpenAI LLM with optional reasoning support."""
6171
if kwargs.get("tools"):

src/memos/mem_reader/simple_struct.py

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def get_memory(
453453
@staticmethod
454454
def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]:
455455
"""Parse index-keyed JSON from hallucination filter response.
456-
Expected shape: { "0": {"delete": bool, "rewritten": str, "reason": str}, ... }
456+
Expected shape: { "0": {"need_rewrite": bool, "rewritten_suffix": str, "reason": str}, ... }
457457
Returns (success, parsed_dict) with int keys.
458458
"""
459459
try:
@@ -476,27 +476,33 @@ def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dic
476476
continue
477477
if not isinstance(v, dict):
478478
continue
479-
delete_flag = v.get("delete")
480-
rewritten = v.get("rewritten", "")
479+
need_rewrite = v.get("need_rewrite")
480+
rewritten_suffix = v.get("rewritten_suffix", "")
481481
reason = v.get("reason", "")
482482
if (
483-
isinstance(delete_flag, bool)
484-
and isinstance(rewritten, str)
483+
isinstance(need_rewrite, bool)
484+
and isinstance(rewritten_suffix, str)
485485
and isinstance(reason, str)
486486
):
487-
result[idx] = {"delete": delete_flag, "rewritten": rewritten, "reason": reason}
487+
result[idx] = {
488+
"need_rewrite": need_rewrite,
489+
"rewritten_suffix": rewritten_suffix,
490+
"reason": reason,
491+
}
488492

489493
return (len(result) > 0), result
490494

491495
def filter_hallucination_in_memories(
492-
self, user_messages: list[str], memory_list: list[TextualMemoryItem]
496+
self, messages: list[dict], memory_list: list[TextualMemoryItem]
493497
) -> list[TextualMemoryItem]:
494-
flat_memories = [one.memory for one in memory_list]
498+
# Build input objects with memory text and metadata (timestamps, sources, etc.)
495499
template = PROMPT_MAPPING["hallucination_filter"]
496500
prompt_args = {
497-
"user_messages_inline": "\n".join([f"- {memory}" for memory in user_messages]),
501+
"messages_inline": "\n".join(
502+
[f"- [{message['role']}]: {message['content']}" for message in messages]
503+
),
498504
"memories_inline": json.dumps(
499-
{str(i): memory for i, memory in enumerate(flat_memories)},
505+
{idx: mem.memory for idx, mem in enumerate(memory_list)},
500506
ensure_ascii=False,
501507
indent=2,
502508
),
@@ -511,40 +517,31 @@ def filter_hallucination_in_memories(
511517
f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success}"
512518
)
513519
if success:
520+
new_mem_list = []
514521
logger.info(f"Hallucination filter result: {parsed}")
515-
total = len(memory_list)
516-
keep_flags = [True] * total
522+
assert len(parsed) == len(memory_list)
517523
for mem_idx, content in parsed.items():
518-
# Validate index bounds
519-
if not isinstance(mem_idx, int) or mem_idx < 0 or mem_idx >= total:
520-
logger.warning(
521-
f"[filter_hallucination_in_memories] Ignoring out-of-range index: {mem_idx}"
522-
)
523-
continue
524-
525-
delete_flag = content.get("delete", False)
526-
rewritten = content.get("rewritten", None)
524+
need_rewrite = content.get("need_rewrite", False)
525+
rewritten_suffix = content.get("rewritten_suffix", "")
527526
reason = content.get("reason", "")
528527

529-
logger.info(
530-
f"[filter_hallucination_in_memories] index={mem_idx}, delete={delete_flag}, rewritten='{(rewritten or '')[:100]}', reason='{reason[:120]}'"
531-
)
528+
# Append a new memory item instead of replacing the original
529+
if (
530+
need_rewrite
531+
and isinstance(rewritten_suffix, str)
532+
and len(rewritten_suffix.strip()) > 0
533+
):
534+
original_text = memory_list[mem_idx].memory
535+
536+
logger.info(
537+
f"[filter_hallucination_in_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten_suffix='{rewritten_suffix}', reason='{reason}', original memory='{original_text}', action='append_suffix'"
538+
)
532539

533-
if delete_flag is True and rewritten is not None:
534-
# Mark for deletion
535-
keep_flags[mem_idx] = False
540+
# Append only the suffix to the original memory text
541+
memory_list[mem_idx].memory = original_text + rewritten_suffix
542+
new_mem_list.append(memory_list[mem_idx])
536543
else:
537-
# Apply rewrite if provided (safe-by-default: keep item when not mentioned or delete=False)
538-
try:
539-
if isinstance(rewritten, str):
540-
memory_list[mem_idx].memory = rewritten
541-
except Exception as e:
542-
logger.warning(
543-
f"[filter_hallucination_in_memories] Failed to apply rewrite for index {mem_idx}: {e}"
544-
)
545-
546-
# Build result, preserving original order; keep items not mentioned by LLM by default
547-
new_mem_list = [memory_list[i] for i in range(total) if keep_flags[i]]
544+
new_mem_list.append(memory_list[mem_idx])
548545
return new_mem_list
549546
else:
550547
logger.warning("Hallucination filter parsing failed or returned empty result.")
@@ -602,11 +599,8 @@ def _read_memory(
602599
# Build inputs
603600
new_memory_list = []
604601
for unit_messages, unit_memory_list in zip(messages, memory_list, strict=False):
605-
unit_user_messages = [
606-
msg["content"] for msg in unit_messages if msg["role"] == "user"
607-
]
608602
unit_memory_list = self.filter_hallucination_in_memories(
609-
user_messages=unit_user_messages, memory_list=unit_memory_list
603+
messages=unit_messages, memory_list=unit_memory_list
610604
)
611605
new_memory_list.append(unit_memory_list)
612606
memory_list = new_memory_list

src/memos/mem_scheduler/general_scheduler.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,10 @@ def long_memory_update_process(
126126
top_k=self.top_k,
127127
)
128128
logger.info(
129-
f"[long_memory_update_process] Processed {len(queries)} queries {queries} and retrieved {len(new_candidates)} new candidate memories for user_id={user_id}"
129+
# Build the candidate preview string outside the f-string to avoid backslashes in expression
130+
f"[long_memory_update_process] Processed {len(queries)} queries {queries} and retrieved {len(new_candidates)} "
131+
f"new candidate memories for user_id={user_id}: "
132+
+ ("\n- " + "\n- ".join([f"{one.id}: {one.memory}" for one in new_candidates]))
130133
)
131134

132135
# rerank
@@ -141,10 +144,14 @@ def long_memory_update_process(
141144
f"[long_memory_update_process] Final working memory size: {len(new_order_working_memory)} memories for user_id={user_id}"
142145
)
143146

144-
old_memory_texts = [mem.memory for mem in cur_working_memory]
145-
new_memory_texts = [mem.memory for mem in new_order_working_memory]
147+
old_memory_texts = "\n- " + "\n- ".join(
148+
[f"{one.id}: {one.memory}" for one in cur_working_memory]
149+
)
150+
new_memory_texts = "\n- " + "\n- ".join(
151+
[f"{one.id}: {one.memory}" for one in new_order_working_memory]
152+
)
146153

147-
logger.debug(
154+
logger.info(
148155
f"[long_memory_update_process] For user_id='{user_id}', mem_cube_id='{mem_cube_id}': "
149156
f"Scheduler replaced working memory based on query history {queries}. "
150157
f"Old working memory ({len(old_memory_texts)} items): {old_memory_texts}. "
@@ -1413,20 +1420,21 @@ def process_session_turn(
14131420
logger.info(
14141421
f"[process_session_turn] Searching for missing evidence: '{item}' with top_k={k_per_evidence} for user_id={user_id}"
14151422
)
1416-
info = {
1417-
"user_id": user_id,
1418-
"session_id": "",
1419-
}
14201423

1424+
search_args = {}
14211425
results: list[TextualMemoryItem] = self.retriever.search(
14221426
query=item,
1427+
user_id=user_id,
1428+
mem_cube_id=mem_cube_id,
14231429
mem_cube=mem_cube,
14241430
top_k=k_per_evidence,
14251431
method=self.search_method,
1426-
info=info,
1432+
search_args=search_args,
14271433
)
1434+
14281435
logger.info(
1429-
f"[process_session_turn] Search results for missing evidence '{item}': {[one.memory for one in results]}"
1436+
f"[process_session_turn] Search results for missing evidence '{item}': "
1437+
+ ("\n- " + "\n- ".join([f"{one.id}: {one.memory}" for one in results]))
14301438
)
14311439
new_candidates.extend(results)
14321440
return cur_working_memory, new_candidates

src/memos/mem_scheduler/memory_manage_modules/retriever.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222
from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer
2323
from memos.memories.textual.item import TextualMemoryMetadata
2424
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
25-
from memos.types.general_types import FINE_STRATEGY, FineStrategy
25+
from memos.types.general_types import (
26+
FINE_STRATEGY,
27+
FineStrategy,
28+
SearchMode,
29+
)
2630

2731
# Extract JSON response
2832
from .memory_filter import MemoryFilter
@@ -237,10 +241,12 @@ def recall_for_missing_memories(
237241
def search(
238242
self,
239243
query: str,
244+
user_id: str,
245+
mem_cube_id: str,
240246
mem_cube: GeneralMemCube,
241247
top_k: int,
242248
method: str = TreeTextMemory_SEARCH_METHOD,
243-
info: dict | None = None,
249+
search_args: dict | None = None,
244250
) -> list[TextualMemoryItem]:
245251
"""Search in text memory with the given query.
246252
@@ -253,22 +259,67 @@ def search(
253259
Search results or None if not implemented
254260
"""
255261
text_mem_base = mem_cube.text_mem
262+
# Normalize default for mutable argument
263+
search_args = search_args or {}
256264
try:
257265
if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]:
258266
assert isinstance(text_mem_base, TreeTextMemory)
259-
if info is None:
260-
logger.warning(
261-
"Please input 'info' when use tree.search so that "
262-
"the database would store the consume history."
263-
)
264-
info = {"user_id": "", "session_id": ""}
267+
session_id = search_args.get("session_id", "default_session")
268+
target_session_id = session_id
269+
search_priority = (
270+
{"session_id": target_session_id} if "session_id" in search_args else None
271+
)
272+
search_filter = search_args.get("filter")
273+
search_source = search_args.get("source")
274+
plugin = bool(search_source is not None and search_source == "plugin")
275+
user_name = search_args.get("user_name", mem_cube_id)
276+
internet_search = search_args.get("internet_search", False)
277+
chat_history = search_args.get("chat_history")
278+
search_tool_memory = search_args.get("search_tool_memory", False)
279+
tool_mem_top_k = search_args.get("tool_mem_top_k", 6)
280+
playground_search_goal_parser = search_args.get(
281+
"playground_search_goal_parser", False
282+
)
265283

266-
mode = "fast" if method == TreeTextMemory_SEARCH_METHOD else "fine"
267-
results_long_term = text_mem_base.search(
268-
query=query, top_k=top_k, memory_type="LongTermMemory", mode=mode, info=info
284+
info = search_args.get(
285+
"info",
286+
{
287+
"user_id": user_id,
288+
"session_id": target_session_id,
289+
"chat_history": chat_history,
290+
},
269291
)
270-
results_user = text_mem_base.search(
271-
query=query, top_k=top_k, memory_type="UserMemory", mode=mode, info=info
292+
293+
results_long_term = mem_cube.text_mem.search(
294+
query=query,
295+
user_name=user_name,
296+
top_k=top_k,
297+
mode=SearchMode.FAST,
298+
manual_close_internet=not internet_search,
299+
memory_type="LongTermMemory",
300+
search_filter=search_filter,
301+
search_priority=search_priority,
302+
info=info,
303+
plugin=plugin,
304+
search_tool_memory=search_tool_memory,
305+
tool_mem_top_k=tool_mem_top_k,
306+
playground_search_goal_parser=playground_search_goal_parser,
307+
)
308+
309+
results_user = mem_cube.text_mem.search(
310+
query=query,
311+
user_name=user_name,
312+
top_k=top_k,
313+
mode=SearchMode.FAST,
314+
manual_close_internet=not internet_search,
315+
memory_type="UserMemory",
316+
search_filter=search_filter,
317+
search_priority=search_priority,
318+
info=info,
319+
plugin=plugin,
320+
search_tool_memory=search_tool_memory,
321+
tool_mem_top_k=tool_mem_top_k,
322+
playground_search_goal_parser=playground_search_goal_parser,
272323
)
273324
results = results_long_term + results_user
274325
else:

0 commit comments

Comments
 (0)