diff --git a/examples/basic_modules/tree_textual_memory_relation_reason_detector.py b/examples/basic_modules/tree_textual_memory_relation_reason_detector.py index 72e4deb60..294f8973a 100644 --- a/examples/basic_modules/tree_textual_memory_relation_reason_detector.py +++ b/examples/basic_modules/tree_textual_memory_relation_reason_detector.py @@ -27,6 +27,8 @@ ) embedder = EmbedderFactory.from_config(embedder_config) +user_name = "lucy4" + # === Step 2: Initialize Neo4j GraphStore === graph_config = GraphDBConfigFactory( backend="neo4j", @@ -34,7 +36,7 @@ "uri": "bolt://localhost:7687", "user": "neo4j", "password": "12345678", - "db_name": "lucy4", + "db_name": user_name, "auto_create": True, }, ) @@ -178,6 +180,7 @@ results = relation_detector.process_node( node=node, + user_name=user_name, exclude_ids=[node.id], # Exclude self when searching for neighbors top_k=5, ) diff --git a/examples/core_memories/tree_textual_memory.py b/examples/core_memories/tree_textual_memory.py index d2e197e5b..0db3af196 100644 --- a/examples/core_memories/tree_textual_memory.py +++ b/examples/core_memories/tree_textual_memory.py @@ -172,7 +172,8 @@ added_ids = my_tree_textual_memory.add(m_list) for i, id in enumerate(added_ids): print(f"{i}'th added result is:" + my_tree_textual_memory.get(id).memory) - my_tree_textual_memory.memory_manager.wait_reorganizer() + # wait the synchronous thread + # TODO: USE SCHEDULE MODULE TO WAIT time.sleep(60) @@ -233,7 +234,8 @@ for m_list in doc_memory: added_ids = my_tree_textual_memory.add(m_list) - my_tree_textual_memory.memory_manager.wait_reorganizer() + # wait the synchronous thread + # TODO: USE SCHEDULE MODULE TO WAIT results = my_tree_textual_memory.search( "Tell me about what memos consist of?", @@ -245,9 +247,10 @@ print(f"{i}'th similar result is: " + str(r["memory"])) print(f"Successfully search {len(results)} memories") -# close the synchronous thread in memory manager -my_tree_textual_memory.memory_manager.close() +# close the synchronous thread +# TODO: USE SCHEDULE MODULE TO CLOSE # my_tree_textual_memory.dump +# Note that you cannot drop this tree when`use_multi_db` == +# false. my_tree_textual_memory.drop() """ my_tree_textual_memory.dump("tmp/my_tree_textual_memory") -my_tree_textual_memory.drop() diff --git a/src/memos/graph_dbs/base.py b/src/memos/graph_dbs/base.py index b26db5afa..ba1611cbf 100644 --- a/src/memos/graph_dbs/base.py +++ b/src/memos/graph_dbs/base.py @@ -70,7 +70,7 @@ def edge_exists(self, source_id: str, target_id: str, type: str) -> bool: # Graph Query & Reasoning @abstractmethod - def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None: + def get_node(self, id: str, include_embedding: bool = False, **kwargs) -> dict[str, Any] | None: """ Retrieve the metadata and content of a node. Args: diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 12b493e58..dd810c9bc 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -1174,7 +1174,7 @@ def get_grouped_counts( group_by_fields.append(alias) # Full GQL query construction gql = f""" - MATCH (n /*+ INDEX(idx_memory_user_name) */) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) {where_clause} RETURN {", ".join(return_fields)}, COUNT(n) AS count """ @@ -1381,31 +1381,55 @@ def get_structure_optimization_candidates( where_clause += f' AND n.user_name = "{user_name}"' return_fields = self._build_return_fields(include_embedding) - return_fields += f", n.{self.dim_field} AS {self.dim_field}" - query = f""" + gql = f""" MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_clause} - OPTIONAL MATCH (n)-[@PARENT]->(c@Memory) - OPTIONAL MATCH (p@Memory)-[@PARENT]->(n) - WHERE c IS NULL AND p IS NULL - RETURN {return_fields} + OPTIONAL MATCH (n)-[@PARENT]->(c@Memory {{user_name: "{user_name}"}}) + OPTIONAL MATCH (p@Memory {{user_name: "{user_name}"}})-[@PARENT]->(n) + RETURN {return_fields}, + c.id AS child_id, + p.id AS parent_id """ - candidates = [] - node_ids = set() + per_node_seen_has_child_or_parent: dict[str, bool] = {} + per_node_payload: dict[str, dict] = {} + try: - results = self.execute_query(query) - for row in results: - props = {k: v.value for k, v in row.items()} - node = self._parse_node(props) - node_id = node["id"] - if node_id not in node_ids: - candidates.append(node) - node_ids.add(node_id) + results = self.execute_query(gql) except Exception as e: - logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}") - return candidates + logger.error( + f"[get_structure_optimization_candidates] Query failed: {e}, " + f"traceback: {traceback.format_exc()}" + ) + return [] + + for row in results: + props = {k: v.value for k, v in row.items() if k not in ("child_id", "parent_id")} + node = self._parse_node(props) + nid = node["id"] + + if nid not in per_node_payload: + per_node_payload[nid] = node + per_node_seen_has_child_or_parent[nid] = False + + child_val = row.get("child_id") + parent_val = row.get("parent_id") + + child_unwrapped = self._parse_value(child_val) if (child_val is not None) else None + parent_unwrapped = self._parse_value(parent_val) if (parent_val is not None) else None + + if child_unwrapped: + per_node_seen_has_child_or_parent[nid] = True + if parent_unwrapped: + per_node_seen_has_child_or_parent[nid] = True + + isolated_nodes: list[dict] = [] + for nid, node_obj in per_node_payload.items(): + if not per_node_seen_has_child_or_parent[nid]: + isolated_nodes.append(node_obj) + + return isolated_nodes @timed def drop_database(self) -> None: @@ -1450,7 +1474,7 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: @timed def get_neighbors( - self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" + self, id: str, type: str, direction: Literal["in", "out", "both"] = "both" ) -> list[str]: """ Get connected node IDs in a specific direction and relationship type. @@ -1461,7 +1485,70 @@ def get_neighbors( Returns: List of neighboring node IDs. """ - raise NotImplementedError + if direction not in ("in", "out", "both"): + raise ValueError(f"Unsupported direction: {direction}") + + user_name = self.config.user_name + id_val = self._format_value(id) # e.g. '"5225-uuid..."' + user_val = self._format_value(user_name) # e.g. '"lme_user_1"' + edge_type = type # assume caller passes valid edge tag + + def _run_out_query() -> list[str]: + # out: (this)-[edge_type]->(dst) + gql = f""" + MATCH (src@Memory {{id: {id_val}, user_name: {user_val}}}) + -[r@{edge_type}]-> + (dst@Memory {{user_name: {user_val}}}) + RETURN DISTINCT dst.id AS neighbor + """.strip() + try: + result = self.execute_query(gql) + except Exception as e: + logger.error(f"[get_neighbors][out] Query failed: {e}, gql={gql}") + return [] + + out_ids = [] + try: + for row in result: + out_ids.append(row["neighbor"].value) + except Exception as e: + logger.error(f"[get_neighbors][out] Parse failed: {e}") + return out_ids + + def _run_in_query() -> list[str]: + # in: (src)-[edge_type]->(this) + gql = f""" + MATCH (src@Memory {{user_name: {user_val}}}) + -[r@{edge_type}]-> + (dst@Memory {{id: {id_val}, user_name: {user_val}}}) + RETURN DISTINCT src.id AS neighbor + """.strip() + try: + result = self.execute_query(gql) + except Exception as e: + logger.error(f"[get_neighbors][in] Query failed: {e}, gql={gql}") + return [] + + in_ids = [] + try: + for row in result: + in_ids.append(row["neighbor"].value) + except Exception as e: + logger.error(f"[get_neighbors][in] Parse failed: {e}") + return in_ids + + if direction == "out": + return list(set(_run_out_query())) + elif direction == "in": + return list(set(_run_in_query())) + else: # direction == "both" + out_ids = _run_out_query() + in_ids = _run_in_query() + merged = set(out_ids) + merged.update(in_ids) + if id in merged: + merged.remove(id) + return list(merged) @timed def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index ec8a673d7..736b04b74 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -18,6 +18,7 @@ from memos.mem_scheduler.schemas.general_schemas import ( ADD_LABEL, ANSWER_LABEL, + MEM_ORGANIZE_LABEL, MEM_READ_LABEL, PREF_ADD_LABEL, QUERY_LABEL, @@ -166,25 +167,6 @@ def mem_scheduler_off(self) -> bool: logger.error(f"Failed to stop scheduler: {e!s}") return False - def mem_reorganizer_on(self) -> bool: - pass - - def mem_reorganizer_off(self) -> bool: - """temporally implement""" - for mem_cube in self.mem_cubes.values(): - logger.info(f"try to close reorganizer for {mem_cube.text_mem.config.cube_id}") - if mem_cube.text_mem and mem_cube.text_mem.is_reorganize: - logger.info(f"close reorganizer for {mem_cube.text_mem.config.cube_id}") - mem_cube.text_mem.memory_manager.close() - mem_cube.text_mem.memory_manager.wait_reorganizer() - - def mem_reorganizer_wait(self) -> bool: - for mem_cube in self.mem_cubes.values(): - logger.info(f"try to close reorganizer for {mem_cube.text_mem.config.cube_id}") - if mem_cube.text_mem and mem_cube.text_mem.is_reorganize: - logger.info(f"close reorganizer for {mem_cube.text_mem.config.cube_id}") - mem_cube.text_mem.memory_manager.wait_reorganizer() - def _register_chat_history( self, user_id: str | None = None, session_id: str | None = None ) -> None: @@ -727,9 +709,12 @@ def add( f"time add: get mem_cube_id time user_id: {target_user_id} time is: {time.time() - time_start}" ) + time_start_0 = time.time() if mem_cube_id not in self.mem_cubes: raise ValueError(f"MemCube '{mem_cube_id}' is not loaded. Please register.") - + logger.info( + f"time add: get mem_cube_id check in mem_cubes time user_id: {target_user_id} time is: {time.time() - time_start_0}" + ) sync_mode = self.mem_cubes[mem_cube_id].text_mem.mode if sync_mode == "async": assert self.mem_scheduler is not None, ( @@ -779,16 +764,25 @@ def process_textual_memory(): timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) - - message_item = ScheduleMessageItem( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) + elif sync_mode == "sync": + message_item = ScheduleMessageItem( + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=MEM_ORGANIZE_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) def process_preference_memory(): if ( diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index fed8f7278..d814d82c4 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -3,6 +3,7 @@ import os import random import time +import traceback from collections.abc import Generator from datetime import datetime @@ -215,7 +216,7 @@ def _restore_user_instances( logger.error(f"Failed to restore user configuration for {user_id}: {e}") except Exception as e: - logger.error(f"Error during user instance restoration: {e}") + logger.error(f"Error during user instance restoration: {e}: {traceback.print_exc()}") def _initialize_cube_from_default_config( self, cube_id: str, user_id: str, default_config: GeneralMemCubeConfig diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 0f74adead..fe6249027 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -51,7 +51,8 @@ _ENC = tiktoken.get_encoding("cl100k_base") def _count_tokens_text(s: str) -> int: - return len(_ENC.encode(s or "")) + # allow special tokens like <|endoftext|> instead of raising ValueError + return len(_ENC.encode(s or "", disallowed_special=())) except Exception: # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars def _count_tokens_text(s: str) -> int: diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index e475ea225..767301256 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -98,6 +98,9 @@ def __init__(self, config: BaseSchedulerConfig): self.max_internal_message_queue_size = self.config.get( "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE ) + self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( + maxsize=self.max_internal_message_queue_size + ) # Initialize message queue based on configuration if self.use_redis_queue: diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index d84ebb242..124206565 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -1,7 +1,10 @@ import concurrent.futures import json +import threading import traceback +from datetime import datetime + from memos.configs.mem_scheduler import GeneralSchedulerConfig from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube @@ -47,6 +50,11 @@ def __init__(self, config: GeneralSchedulerConfig): } self.dispatcher.register_handlers(handlers) + # Lazy-initialize reorganize state only if organize handler is enabled + if handlers.get(MEM_ORGANIZE_LABEL): + self._reorg_state = {} + self._reorg_locks = {} + def long_memory_update_process( self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem] ): @@ -347,6 +355,8 @@ def _process_memories_with_reader( flattened_memories = [] for memory_list in processed_memories: flattened_memories.extend(memory_list) + for mem in memory_list: + logger.debug(f"Add Processed Mem Reader Mem: {mem.id}: {mem.memory}") logger.info(f"mem_reader processed {len(flattened_memories)} enhanced memories") @@ -356,6 +366,29 @@ def _process_memories_with_reader( logger.info( f"Added {len(enhanced_mem_ids)} enhanced memories: {enhanced_mem_ids}" ) + # Trigger organize only when we really added new nodes + try: + if "handlers" in dir(self.dispatcher) and MEM_ORGANIZE_LABEL not in getattr( + self.dispatcher, "handlers", {} + ): + # Dispatcher exists but organize not enabled; skip enqueue. + pass + elif not getattr( + text_mem.memory_manager.reorganizer, "is_reorganize", True + ): + pass + else: + message_item = ScheduleMessageItem( + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=MEM_ORGANIZE_LABEL, + content=json.dumps(enhanced_mem_ids), + timestamp=datetime.utcnow(), + ) + self.submit_messages(messages=[message_item]) + except Exception as e: + logger.error(f"Failed to enqueue MEM_ORGANIZE task: {e}", exc_info=True) else: logger.info("No enhanced memories generated by mem_reader") else: @@ -373,53 +406,132 @@ def _process_memories_with_reader( ) def _mem_reorganize_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.info(f"Messages {messages} assigned to {MEM_READ_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {MEM_ORGANIZE_LABEL} handler.") - def process_message(message: ScheduleMessageItem): + # Group by cube; we only trigger once per cube per batch. + grouped_by_cube: dict[str, GeneralMemCube] = {} + for msg in messages: try: - user_id = message.user_id - mem_cube_id = message.mem_cube_id - mem_cube = message.mem_cube - content = message.content + if msg.mem_cube_id and msg.mem_cube: + grouped_by_cube[msg.mem_cube_id] = msg.mem_cube + except Exception: + continue - # Parse the memory IDs from content - mem_ids = json.loads(content) if isinstance(content, str) else content - if not mem_ids: - return + if not grouped_by_cube: + logger.debug("[Reorganize] No valid mem_cube in messages; skip.") + return - logger.info( - f"Processing mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}, mem_ids={mem_ids}" + # Fire reorganize in parallel across different cubes; each cube is single-flight via lock. + with concurrent.futures.ThreadPoolExecutor( + max_workers=min(8, len(grouped_by_cube)) + ) as executor: + futures = [] + for mem_cube_id, mem_cube in grouped_by_cube.items(): + futures.append( + executor.submit(self._run_reorganize_singleflight, mem_cube, mem_cube_id, None) ) - # Get the text memory from the mem_cube - text_mem = mem_cube.text_mem - if not isinstance(text_mem, TreeTextMemory): - logger.error(f"Expected TreeTextMemory but got {type(text_mem).__name__}") - return + for f in concurrent.futures.as_completed(futures): + try: + f.result() + except Exception as e: + logger.error(f"[Reorganize] Task failed: {e}", exc_info=True) - # Use mem_reader to process the memories - self._process_memories_with_reorganize( - mem_ids=mem_ids, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - text_mem=text_mem, + def _get_reorg_lock(self, mem_cube_id: str) -> threading.Lock: + """ + Return a per-cube lock; lazily create it only when MEM_ORGANIZE is enabled. + """ + # If organize handler is disabled, this dict may not exist; guard it. + if not hasattr(self, "_reorg_locks"): + self._reorg_locks = {} + lock = self._reorg_locks.get(mem_cube_id) + if lock is None: + lock = threading.Lock() + self._reorg_locks[mem_cube_id] = lock + return lock + + def _get_reorg_state(self, mem_cube_id: str): + if not hasattr(self, "_reorg_state"): + self._reorg_state = {} + st = self._reorg_state.get(mem_cube_id) + if st is None: + st = {"running": False, "rerun_requested": False} + self._reorg_state[mem_cube_id] = st + return st + + def _run_reorganize_singleflight( + self, + mem_cube: GeneralMemCube, + mem_cube_id: str, + scopes: list[str] | None = None, + ) -> None: + """ + Run one reorganize pass for a mem_cube ensuring single-flight per cube. + If `scopes` is None, run both LongTermMemory and UserMemory (safe default). + """ + lock = self._get_reorg_lock(mem_cube_id) + state = self._get_reorg_state(mem_cube_id) + + with lock: + if state["running"]: + state["rerun_requested"] = True + print( + f"[Reorganize] Already running for {mem_cube_id}; mark trailing rerun and skip." ) + return + state["running"] = True + state["rerun_requested"] = False - logger.info( - f"Successfully processed mem_read for user_id={user_id}, mem_cube_id={mem_cube_id}" - ) + try: + print("state is not running.. start to run!") + # ===== Run First Turn ===== + self._run_reorganize_once(mem_cube, mem_cube_id, scopes) + + # ===== Run Trailing Turn ===== + do_trailing = False + with lock: + if state["rerun_requested"]: + state["rerun_requested"] = False + do_trailing = True + + if do_trailing: + logger.info(f"[Reorganize] Running single trailing pass for {mem_cube_id}.") + self._run_reorganize_once(mem_cube, mem_cube_id, scopes) + + finally: + with lock: + state["running"] = False + + def _run_reorganize_once( + self, mem_cube: GeneralMemCube, mem_cube_id: str, scopes: list[str] | None + ): + print(f"[Reorganize] Acquired lock for mem_cube_id={mem_cube_id}; starting reorganize.") + text_mem = mem_cube.text_mem + if not isinstance(text_mem, TreeTextMemory): + logger.error( + f"[Reorganize] Expected TreeTextMemory but got {type(text_mem).__name__} for mem_cube_id={mem_cube_id}" + ) + return - except Exception as e: - logger.error(f"Error processing mem_read message: {e}", exc_info=True) + reorganizer = text_mem.memory_manager.reorganizer + if not reorganizer or not getattr(reorganizer, "is_reorganize", True): + logger.debug( + f"[Reorganize] Reorganizer disabled or missing for mem_cube_id={mem_cube_id}; skip." + ) + return - with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: - futures = [executor.submit(process_message, msg) for msg in messages] - for future in concurrent.futures.as_completed(futures): - try: - future.result() - except Exception as e: - logger.error(f"Thread task failed: {e}", exc_info=True) + run_scopes = scopes or ["LongTermMemory", "UserMemory"] + for scope in run_scopes: + logger.info( + f"[Reorganize] Start optimize_structure(scope={scope}) for mem_cube_id={mem_cube_id}" + ) + try: + reorganizer.optimize_structure(scope=scope) + except Exception as e: + logger.warning( + f"[Reorganize] optimize_structure failed for scope={scope}, mem_cube_id={mem_cube_id}: {e}", + exc_info=True, + ) def _process_memories_with_reorganize( self, diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index f6254efbb..ca4712814 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -42,6 +42,11 @@ class SourceMessage(BaseModel): model_config = ConfigDict(extra="allow") + @property + def content_safe(self) -> str: + """Always return a string, fallback to '' if content is None.""" + return self.content or "" + class TextualMemoryMetadata(BaseModel): """Metadata for a memory item. diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 54776134b..776f71765 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -9,11 +9,8 @@ from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.log import get_logger -from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata -from memos.memories.textual.tree_text_memory.organize.reorganizer import ( - GraphStructureReorganizer, - QueueMessage, -) +from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.tree_text_memory.organize.reorganizer import GraphStructureReorganizer logger = get_logger(__name__) @@ -27,7 +24,6 @@ def __init__( llm: OpenAILLM | OllamaLLM | AzureLLM, memory_size: dict | None = None, threshold: float | None = 0.80, - merged_threshold: float | None = 0.92, is_reorganize: bool = False, ): self.graph_store = graph_store @@ -50,7 +46,6 @@ def __init__( self.reorganizer = GraphStructureReorganizer( graph_store, llm, embedder, is_reorganize=is_reorganize ) - self._merged_threshold = merged_threshold def add( self, memories: list[TextualMemoryItem], user_name: str | None = None, mode: str = "sync" @@ -194,92 +189,6 @@ def _add_to_graph_memory( memory.metadata.model_dump(exclude_none=True), user_name=user_name, ) - self.reorganizer.add_message( - QueueMessage( - op="add", - after_node=[node_id], - ) - ) - return node_id - - def _inherit_edges(self, from_id: str, to_id: str) -> None: - """ - Migrate all non-lineage edges from `from_id` to `to_id`, - and remove them from `from_id` after copying. - """ - edges = self.graph_store.get_edges(from_id, type="ANY", direction="ANY") - - for edge in edges: - if edge["type"] == "MERGED_TO": - continue # Keep lineage edges - - new_from = to_id if edge["from"] == from_id else edge["from"] - new_to = to_id if edge["to"] == from_id else edge["to"] - - if new_from == new_to: - continue - - # Add edge to merged node if it doesn't already exist - if not self.graph_store.edge_exists(new_from, new_to, edge["type"], direction="ANY"): - self.graph_store.add_edge(new_from, new_to, edge["type"]) - - # Remove original edge if it involved the archived node - self.graph_store.delete_edge(edge["from"], edge["to"], edge["type"]) - - def _ensure_structure_path( - self, memory_type: str, metadata: TreeNodeTextualMemoryMetadata - ) -> str: - """ - Ensure structural path exists (ROOT → ... → final node), return last node ID. - - Args: - path: like ["hobby", "photography"] - - Returns: - Final node ID of the structure path. - """ - # Step 1: Try to find an existing memory node with content == tag - existing = self.graph_store.get_by_metadata( - [ - {"field": "memory", "op": "=", "value": metadata.key}, - {"field": "memory_type", "op": "=", "value": memory_type}, - ] - ) - if existing: - node_id = existing[0] # Use the first match - else: - # Step 2: If not found, create a new structure node - new_node = TextualMemoryItem( - memory=metadata.key, - metadata=TreeNodeTextualMemoryMetadata( - user_id=metadata.user_id, - session_id=metadata.session_id, - memory_type=memory_type, - status="activated", - tags=[], - key=metadata.key, - embedding=self.embedder.embed([metadata.key])[0], - usage=[], - sources=[], - confidence=0.99, - background="", - ), - ) - self.graph_store.add_node( - id=new_node.id, - memory=new_node.memory, - metadata=new_node.metadata.model_dump(exclude_none=True), - ) - self.reorganizer.add_message( - QueueMessage( - op="add", - after_node=[new_node.id], - ) - ) - - node_id = new_node.id - - # Step 3: Return this structure node ID as the parent_id return node_id def remove_and_refresh_memory(self): @@ -306,17 +215,3 @@ def _cleanup_memories_if_needed(self) -> None: logger.debug(f"Cleaned up {memory_type}: {current_count} -> {limit}") except Exception: logger.warning(f"Remove {memory_type} error: {traceback.format_exc()}") - - def wait_reorganizer(self): - """ - Wait for the reorganizer to finish processing all messages. - """ - logger.debug("Waiting for reorganizer to finish processing messages...") - self.reorganizer.wait_until_current_task_done() - - def close(self): - self.wait_reorganizer() - self.reorganizer.stop() - - def __del__(self): - self.close() diff --git a/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py b/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py index ad9dcb2b8..2d8b72ecc 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +++ b/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py @@ -1,9 +1,9 @@ import json import traceback -from memos.embedders.factory import OllamaEmbedder +from memos.embedders.base import BaseEmbedder +from memos.graph_dbs.base import BaseGraphDB from memos.graph_dbs.item import GraphDBNode -from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.llms.base import BaseLLM from memos.log import get_logger from memos.memories.textual.item import TreeNodeTextualMemoryMetadata @@ -18,12 +18,18 @@ class RelationAndReasoningDetector: - def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: OllamaEmbedder): + def __init__(self, graph_store: BaseGraphDB, llm: BaseLLM, embedder: BaseEmbedder): self.graph_store = graph_store self.llm = llm self.embedder = embedder - def process_node(self, node: GraphDBNode, exclude_ids: list[str], top_k: int = 5): + def process_node( + self, + node: GraphDBNode, + exclude_ids: list[str], + top_k: int = 5, + user_name: str | None = None, + ): """ Unified pipeline for: 1) Pairwise relations (cause, condition, conflict, relate) @@ -52,6 +58,7 @@ def process_node(self, node: GraphDBNode, exclude_ids: list[str], top_k: int = 5 exclude_ids=exclude_ids, top_k=top_k, min_overlap=2, + user_name=user_name, ) nearest = [GraphDBNode(**cand_data) for cand_data in nearest] @@ -62,7 +69,7 @@ def process_node(self, node: GraphDBNode, exclude_ids: list[str], top_k: int = 5 """ # 2) Inferred nodes (from causal/condition) - inferred = self._infer_fact_nodes_from_relations(pairwise) + inferred = self._infer_fact_nodes_from_relations(pairwise, user_name=user_name) results["inferred_nodes"].extend(inferred) """ @@ -115,12 +122,18 @@ def _detect_pairwise_causal_condition_relations( return results - def _infer_fact_nodes_from_relations(self, pairwise_results: dict): + def _infer_fact_nodes_from_relations(self, pairwise_results: dict, user_name: str): inferred_nodes = [] for rel in pairwise_results["relations"]: if rel["relation_type"] in ("CAUSE", "CONDITION"): - src = self.graph_store.get_node(rel["source_id"]) - tgt = self.graph_store.get_node(rel["target_id"]) + src = self.graph_store.get_node( + rel["source_id"], + user_name=user_name, + ) + tgt = self.graph_store.get_node( + rel["target_id"], + user_name=user_name, + ) if not src or not tgt: continue diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py index 0337225d1..596f64f70 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py +++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py @@ -1,24 +1,19 @@ import json -import threading -import time import traceback from collections import defaultdict from concurrent.futures import as_completed -from queue import PriorityQueue -from typing import Literal import numpy as np from memos.context.context import ContextThreadPoolExecutor from memos.dependency import require_python_package from memos.embedders.factory import OllamaEmbedder -from memos.graph_dbs.item import GraphDBEdge, GraphDBNode +from memos.graph_dbs.item import GraphDBNode from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.llms.base import BaseLLM from memos.log import get_logger from memos.memories.textual.item import SourceMessage, TreeNodeTextualMemoryMetadata -from memos.memories.textual.tree_text_memory.organize.handler import NodeHandler from memos.memories.textual.tree_text_memory.organize.relation_reason_detector import ( RelationAndReasoningDetector, ) @@ -44,30 +39,6 @@ def build_summary_parent_node(cluster_nodes): return normalized_sources -class QueueMessage: - def __init__( - self, - op: Literal["add", "remove", "merge", "update", "end"], - # `str` for node and edge IDs, `GraphDBNode` and `GraphDBEdge` for actual objects - before_node: list[str] | list[GraphDBNode] | None = None, - before_edge: list[str] | list[GraphDBEdge] | None = None, - after_node: list[str] | list[GraphDBNode] | None = None, - after_edge: list[str] | list[GraphDBEdge] | None = None, - ): - self.op = op - self.before_node = before_node - self.before_edge = before_edge - self.after_node = after_node - self.after_edge = after_edge - - def __str__(self) -> str: - return f"QueueMessage(op={self.op}, before_node={self.before_node if self.before_node is None else len(self.before_node)}, after_node={self.after_node if self.after_node is None else len(self.after_node)})" - - def __lt__(self, other: "QueueMessage") -> bool: - op_priority = {"add": 2, "remove": 2, "merge": 1, "end": 0} - return op_priority[self.op] < op_priority[other.op] - - def extract_first_to_last_brace(text: str): start = text.find("{") end = text.rfind("}") @@ -77,135 +48,92 @@ def extract_first_to_last_brace(text: str): return json_str, json.loads(json_str) -class GraphStructureReorganizer: - def __init__( - self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: OllamaEmbedder, is_reorganize: bool - ): - self.queue = PriorityQueue() # Min-heap - self.graph_store = graph_store - self.llm = llm - self.embedder = embedder - self.relation_detector = RelationAndReasoningDetector( - self.graph_store, self.llm, self.embedder - ) - self.resolver = NodeHandler(graph_store=graph_store, llm=llm, embedder=embedder) +def recursive_clustering( + nodes_list, depth=0, max_cluster_size: int = 20, min_cluster_size: int = 10 +): + """Recursively split clusters until each is <= max_cluster_size.""" + from sklearn.cluster import MiniBatchKMeans - self.is_reorganize = is_reorganize - self._reorganize_needed = True - if self.is_reorganize: - # ____ 1. For queue message driven thread ___________ - self.thread = threading.Thread(target=self._run_message_consumer_loop) - self.thread.start() - # ____ 2. For periodic structure optimization _______ - self._stop_scheduler = False - self._is_optimizing = {"LongTermMemory": False, "UserMemory": False} - self.structure_optimizer_thread = threading.Thread( - target=self._run_structure_organizer_loop - ) - self.structure_optimizer_thread.start() + indent = " " * depth + logger.info(f"{indent}[Recursive] Start clustering {len(nodes_list)} nodes at depth {depth}") - def add_message(self, message: QueueMessage): - self.queue.put_nowait(message) + if len(nodes_list) <= max_cluster_size: + logger.info(f"{indent}[Recursive] Node count <= {max_cluster_size}, stop splitting.") + return [nodes_list] + # Try kmeans with k = ceil(len(nodes) / max_cluster_size) + x_nodes = [n for n in nodes_list if n.metadata.embedding] + x = np.array([n.metadata.embedding for n in x_nodes]) - def wait_until_current_task_done(self): - """ - Wait until: - 1) queue is empty - 2) any running structure optimization is done - """ - deadline = time.time() + 600 - if not self.is_reorganize: - return + if len(x) < min_cluster_size: + logger.info(f"{indent}[Recursive] Too few embeddings ({len(x)}), skipping clustering.") + return [nodes_list] - if not self.queue.empty(): - self.queue.join() - logger.debug("Queue is now empty.") + k = min(len(x), (len(nodes_list) + max_cluster_size - 1) // max_cluster_size) + k = max(1, k) - while any(self._is_optimizing.values()): - logger.debug(f"Waiting for structure optimizer to finish... {self._is_optimizing}") - if time.time() > deadline: - logger.error(f"Wait timed out; flags={self._is_optimizing}") - break - time.sleep(1) - logger.debug("Structure optimizer is now idle.") + try: + logger.info(f"{indent}[Recursive] Clustering with k={k} on {len(x)} points.") + kmeans = MiniBatchKMeans(n_clusters=k, batch_size=256, random_state=42) + labels = kmeans.fit_predict(x) - def _run_message_consumer_loop(self): - while True: - message = self.queue.get() - if message.op == "end": - break + label_groups = defaultdict(list) + for node, label in zip(x_nodes, labels, strict=False): + label_groups[label].append(node) - try: - if self._preprocess_message(message): - self.handle_message(message) - except Exception: - logger.error(traceback.format_exc()) - self.queue.task_done() + # Map: label -> nodes with no embedding (fallback group) + no_embedding_nodes = [n for n in nodes_list if not n.metadata.embedding] + if no_embedding_nodes: + logger.warning( + f"{indent}[Recursive] {len(no_embedding_nodes)} nodes have no embedding. Added to largest cluster." + ) + # Assign to the largest cluster + largest_label = max(label_groups.items(), key=lambda kv: len(kv[1]))[0] + label_groups[largest_label].extend(no_embedding_nodes) + + result = [] + for label, sub_group in label_groups.items(): + logger.info(f"{indent} Cluster-{label}: {len(sub_group)} nodes") + result.extend( + recursive_clustering( + sub_group, + depth=depth + 1, + max_cluster_size=max_cluster_size, + min_cluster_size=min_cluster_size, + ) + ) + return result - @require_python_package( - import_name="schedule", - install_command="pip install schedule", - install_link="https://schedule.readthedocs.io/en/stable/installation.html", - ) - def _run_structure_organizer_loop(self): - """ - Use schedule library to periodically trigger structure optimization. - This runs until the stop flag is set. - """ - import schedule - - schedule.every(100).seconds.do(self.optimize_structure, scope="LongTermMemory") - schedule.every(100).seconds.do(self.optimize_structure, scope="UserMemory") - - logger.info("Structure optimizer schedule started.") - while not getattr(self, "_stop_scheduler", False): - if any(self._is_optimizing.values()): - time.sleep(1) - continue - if self._reorganize_needed: - logger.info("[Reorganizer] Triggering optimize_structure due to new nodes.") - self.optimize_structure(scope="LongTermMemory") - self.optimize_structure(scope="UserMemory") - self._reorganize_needed = False - time.sleep(30) - - def stop(self): - """ - Stop the reorganizer thread. - """ - if not self.is_reorganize: - return + except Exception as e: + logger.warning(f"{indent}[Recursive] Clustering failed: {e}, fallback to one cluster.") + return [nodes_list] - self.add_message(QueueMessage(op="end")) - self.thread.join() - logger.info("Reorganize thread stopped.") - self._stop_scheduler = True - self.structure_optimizer_thread.join() - logger.info("Structure optimizer stopped.") - - def handle_message(self, message: QueueMessage): - handle_map = {"add": self.handle_add, "remove": self.handle_remove} - handle_map[message.op](message) - logger.debug(f"message queue size: {self.queue.qsize()}") - - def handle_add(self, message: QueueMessage): - logger.debug(f"Handling add operation: {str(message)[:500]}") - added_node = message.after_node[0] - detected_relationships = self.resolver.detect( - added_node, scope=added_node.metadata.memory_type - ) - if detected_relationships: - for added_node, existing_node, relation in detected_relationships: - self.resolver.resolve(added_node, existing_node, relation) - self._reorganize_needed = True +def _parse_json_result(response_text): + try: + response_text = response_text.replace("```", "").replace("json", "") + response_json = extract_first_to_last_brace(response_text)[1] + return response_json + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse LLM response as JSON: {e}\nRaw response:\n{response_text}") + return {} + - def handle_remove(self, message: QueueMessage): - logger.debug(f"Handling remove operation: {str(message)[:50]}") +class GraphStructureReorganizer: + def __init__( + self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: OllamaEmbedder, is_reorganize: bool + ): + self.graph_store = graph_store + self.llm = llm + self.embedder = embedder + self.relation_detector = RelationAndReasoningDetector( + self.graph_store, self.llm, self.embedder + ) + self.is_reorganize = is_reorganize def optimize_structure( self, scope: str = "LongTermMemory", + user_name: str | None = None, local_tree_threshold: int = 10, min_cluster_size: int = 4, min_group_size: int = 20, @@ -218,6 +146,8 @@ def optimize_structure( 3. Create parent nodes and build local PARENT trees. """ # --- Total time watch dog: check functions --- + import time + start_ts = time.time() def _check_deadline(where: str): @@ -229,29 +159,43 @@ def _check_deadline(where: str): return True return False - if self._is_optimizing[scope]: - logger.info(f"[GraphStructureReorganize] Already optimizing for {scope}. Skipping.") - return - - if self.graph_store.node_not_exist(scope): + if self.graph_store.node_not_exist(scope, user_name=user_name): logger.debug(f"[GraphStructureReorganize] No nodes for scope={scope}. Skip.") return - self._is_optimizing[scope] = True try: logger.debug( f"[GraphStructureReorganize] 🔍 Starting structure optimization for scope: {scope}" ) - logger.debug( f"[GraphStructureReorganize] Num of scope in self.graph_store is" - f" {self.graph_store.get_memory_count(scope)}" + f" {self.graph_store.get_memory_count(scope, user_name=user_name)}" ) # Load candidate nodes if _check_deadline("[GraphStructureReorganize] Before loading candidates"): return - raw_nodes = self.graph_store.get_structure_optimization_candidates(scope) - nodes = [GraphDBNode(**n) for n in raw_nodes] + raw_nodes = self.graph_store.get_structure_optimization_candidates( + scope, user_name=user_name, include_embedding=True + ) + logger.debug( + f"[GraphStructureReorganize] Find {len(raw_nodes)} nodes to optimize" + f"which is {[node['id'] for node in raw_nodes]}" + ) + + def _norm(s): + return s.strip().lower() if isinstance(s, str) else s + + filtered_raw = [] + for n in raw_nodes: + tags = (n.get("metadata") or {}).get("tags") or [] + if not any(_norm(t) == "mode:fast" for t in tags if isinstance(t, str)): + filtered_raw.append(n) + dropped = len(raw_nodes) - len(filtered_raw) + if dropped: + logger.info( + f"[GraphStructureReorganize] Tag filter dropped {dropped} nodes (mode:fast)." + ) + nodes = [GraphDBNode(**n) for n in filtered_raw] if not nodes: logger.info("[GraphStructureReorganize] No nodes to optimize. Skipping.") @@ -266,10 +210,9 @@ def _check_deadline(where: str): if _check_deadline("[GraphStructureReorganize] Before partition"): return partitioned_groups = self._partition(nodes) - logger.info( + logger.debug( f"[GraphStructureReorganize] Partitioned into {len(partitioned_groups)} clusters." ) - if _check_deadline("[GraphStructureReorganize] Before submit partition task"): return with ContextThreadPoolExecutor(max_workers=4) as executor: @@ -282,6 +225,8 @@ def _check_deadline(where: str): scope, local_tree_threshold, min_cluster_size, + user_name, + _check_deadline, ) ) @@ -299,7 +244,6 @@ def _check_deadline(where: str): logger.info("[GraphStructure Reorganize] Structure optimization finished.") finally: - self._is_optimizing[scope] = False logger.info("[GraphStructureReorganize] Structure optimization finished.") def _process_cluster_and_write( @@ -308,6 +252,8 @@ def _process_cluster_and_write( scope: str, local_tree_threshold: int, min_cluster_size: int, + user_name: str, + check_deadline_func, ): if len(cluster_nodes) <= min_cluster_size: return @@ -316,19 +262,46 @@ def _process_cluster_and_write( sub_clusters = self._local_subcluster(cluster_nodes) sub_parents = [] - for sub_nodes in sub_clusters: - if len(sub_nodes) < min_cluster_size: - continue # Skip tiny noise - sub_parent_node = self._summarize_cluster(sub_nodes, scope) - self._create_parent_node(sub_parent_node) - self._link_cluster_nodes(sub_parent_node, sub_nodes) - sub_parents.append(sub_parent_node) - + def _process_one_subcluster(sub_nodes): + try: + sub_parent_node = self._summarize_cluster(sub_nodes, scope) + self._create_parent_node(sub_parent_node, user_name) + self._link_cluster_nodes(sub_parent_node, sub_nodes, user_name) + sub_nodes_str = "\n|_____".join([sub_node.memory for sub_node in sub_nodes]) + logger.debug( + f"Processed a group by nodes. \nThe Structure is: " + f"\n Parent Node: {sub_parent_node.memory}\n" + f"\n Child Node: {sub_nodes_str}" + ) + return sub_parent_node + except Exception as e: + logger.warning(f"Process sub-cluster failed: {e}", exc_info=True) + return None + + valid_sub_clusters = [sc for sc in sub_clusters if len(sc) >= min_cluster_size] + + max_workers = min(4, len(valid_sub_clusters)) + if max_workers > 0: + with ContextThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit(_process_one_subcluster, sc) for sc in valid_sub_clusters + ] + for fut in as_completed(futures): + res = fut.result() + if res is not None: + sub_parents.append(res) + + logger.debug(f"len of sub-parents: {len(sub_parents)}") if sub_parents and len(sub_parents) >= min_cluster_size: cluster_parent_node = self._summarize_cluster(cluster_nodes, scope) - self._create_parent_node(cluster_parent_node) + logger.debug( + f"Find cluster_parent node: {cluster_parent_node.id}: {cluster_parent_node.memory}" + ) + self._create_parent_node(cluster_parent_node, user_name) for sub_parent in sub_parents: - self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT") + self.graph_store.add_edge( + cluster_parent_node.id, sub_parent.id, "PARENT", user_name=user_name + ) logger.info("Adding relations/reasons") nodes_to_check = cluster_nodes @@ -343,19 +316,34 @@ def _process_cluster_and_write( node, exclude_ids, 10, # top_k + user_name=user_name, ) ) - for f in as_completed(futures, timeout=300): - results = f.result() + for f in as_completed(futures): + if check_deadline_func("[GraphStructureReorganize] Relations/reasons"): + for x in futures: + x.cancel() + return + try: + results = f.result() + except Exception as e: + logger.warning(f"Relation task failed: {e}", exc_info=True) + continue # 1) Add pairwise relations for rel in results["relations"]: if not self.graph_store.edge_exists( - rel["source_id"], rel["target_id"], rel["relation_type"] + rel["source_id"], + rel["target_id"], + rel["relation_type"], + user_name=user_name, ): self.graph_store.add_edge( - rel["source_id"], rel["target_id"], rel["relation_type"] + rel["source_id"], + rel["target_id"], + rel["relation_type"], + user_name=user_name, ) # 2) Add inferred nodes and link to sources @@ -364,14 +352,21 @@ def _process_cluster_and_write( inf_node.id, inf_node.memory, inf_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) for src_id in inf_node.metadata.sources: - self.graph_store.add_edge(src_id, inf_node.id, "INFERS") + self.graph_store.add_edge( + src_id, inf_node.id, "INFERS", user_name=user_name + ) # 3) Add sequence links for seq in results["sequence_links"]: - if not self.graph_store.edge_exists(seq["from_id"], seq["to_id"], "FOLLOWS"): - self.graph_store.add_edge(seq["from_id"], seq["to_id"], "FOLLOWS") + if not self.graph_store.edge_exists( + seq["from_id"], seq["to_id"], "FOLLOWS", user_name=user_name + ): + self.graph_store.add_edge( + seq["from_id"], seq["to_id"], "FOLLOWS", user_name=user_name + ) # 4) Add aggregate concept nodes for agg_node in results["aggregate_nodes"]: @@ -379,9 +374,12 @@ def _process_cluster_and_write( agg_node.id, agg_node.memory, agg_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) for child_id in agg_node.metadata.sources: - self.graph_store.add_edge(agg_node.id, child_id, "AGGREGATE_TO") + self.graph_store.add_edge( + agg_node.id, child_id, "AGGREGATE_TO", user_name=user_name + ) logger.info("[Reorganizer] Cluster relation/reasoning done.") @@ -407,7 +405,8 @@ def _local_subcluster( messages = [{"role": "user", "content": prompt}] response_text = self.llm.generate(messages) - response_json = self._parse_json_result(response_text) + response_json = _parse_json_result(response_text) + logger.debug(f"In Sub-Cluster: \ninput: {prompt}\n output: {response_json}") assigned_ids = set() result_subclusters = [] @@ -442,7 +441,6 @@ def _partition(self, nodes, min_cluster_size: int = 10, max_cluster_size: int = Returns: List of clusters, each as a list of GraphDBNode """ - from sklearn.cluster import MiniBatchKMeans if len(nodes) <= max_cluster_size: logger.info( @@ -450,73 +448,33 @@ def _partition(self, nodes, min_cluster_size: int = 10, max_cluster_size: int = ) return [nodes] - def recursive_clustering(nodes_list, depth=0): - """Recursively split clusters until each is <= max_cluster_size.""" - indent = " " * depth - logger.info( - f"{indent}[Recursive] Start clustering {len(nodes_list)} nodes at depth {depth}" - ) - - if len(nodes_list) <= max_cluster_size: - logger.info( - f"{indent}[Recursive] Node count <= {max_cluster_size}, stop splitting." - ) - return [nodes_list] - # Try kmeans with k = ceil(len(nodes) / max_cluster_size) - x_nodes = [n for n in nodes_list if n.metadata.embedding] - x = np.array([n.metadata.embedding for n in x_nodes]) - - if len(x) < min_cluster_size: - logger.info( - f"{indent}[Recursive] Too few embeddings ({len(x)}), skipping clustering." - ) - return [nodes_list] - - k = min(len(x), (len(nodes_list) + max_cluster_size - 1) // max_cluster_size) - k = max(1, k) - - try: - logger.info(f"{indent}[Recursive] Clustering with k={k} on {len(x)} points.") - kmeans = MiniBatchKMeans(n_clusters=k, batch_size=256, random_state=42) - labels = kmeans.fit_predict(x) - - label_groups = defaultdict(list) - for node, label in zip(x_nodes, labels, strict=False): - label_groups[label].append(node) - - # Map: label -> nodes with no embedding (fallback group) - no_embedding_nodes = [n for n in nodes_list if not n.metadata.embedding] - if no_embedding_nodes: - logger.warning( - f"{indent}[Recursive] {len(no_embedding_nodes)} nodes have no embedding. Added to largest cluster." - ) - # Assign to largest cluster - largest_label = max(label_groups.items(), key=lambda kv: len(kv[1]))[0] - label_groups[largest_label].extend(no_embedding_nodes) - - result = [] - for label, sub_group in label_groups.items(): - logger.info(f"{indent} Cluster-{label}: {len(sub_group)} nodes") - result.extend(recursive_clustering(sub_group, depth=depth + 1)) - return result - - except Exception as e: - logger.warning( - f"{indent}[Recursive] Clustering failed: {e}, fallback to one cluster." - ) - return [nodes_list] - - raw_clusters = recursive_clustering(nodes) + raw_clusters = recursive_clustering( + nodes, max_cluster_size=max_cluster_size, min_cluster_size=min_cluster_size + ) filtered_clusters = [c for c in raw_clusters if len(c) > min_cluster_size] logger.info(f"[KMeansPartition] Total clusters before filtering: {len(raw_clusters)}") for i, cluster in enumerate(raw_clusters): - logger.info(f"[KMeansPartition] Cluster-{i}: {len(cluster)} nodes") - - logger.info( + logger.debug(f"[KMeansPartition] Cluster-{i}: {len(cluster)} nodes") + logger.debug(f"[KMeansPartition] Total clusters before filtering: {len(raw_clusters)}") + logger.debug( f"[KMeansPartition] Clusters after filtering (>{min_cluster_size}): {len(filtered_clusters)}" ) + seen_ids = set() + duplicate_ids = set() + + for i, cluster in enumerate(raw_clusters): + ids = [n.id for n in cluster] + mems = [n.memory[:80].replace("\n", " ") + "..." for n in cluster] + logger.debug(f"[Cluster-{i}] size={len(cluster)}") + for nid, mem in zip(ids, mems, strict=False): + logger.debug(f" - id={nid} | mem={mem}") + if nid in seen_ids: + duplicate_ids.add(nid) + else: + seen_ids.add(nid) + return filtered_clusters def _summarize_cluster(self, cluster_nodes: list[GraphDBNode], scope: str) -> GraphDBNode: @@ -526,19 +484,37 @@ def _summarize_cluster(self, cluster_nodes: list[GraphDBNode], scope: str) -> Gr if not cluster_nodes: raise ValueError("Cluster nodes cannot be empty.") - memories_items_text = "\n\n".join( - [ - f"{i}. key: {n.metadata.key}\nvalue: {n.memory}\nsummary:{n.metadata.background}" - for i, n in enumerate(cluster_nodes) - ] - ) + memories_items_text = "" + for i, n in enumerate(cluster_nodes): + # Build raw dialogue excerpt + # We won't hard-cut mid-sentence. We'll collect turns until ~300 chars, then stop before breaking. + excerpt_parts = [] + current_len = 0 + for source_j in n.metadata.sources: + turn_text = f'{source_j.role}: "{source_j.content_safe}"' + # if adding this turn blows us past ~300, break BEFORE adding + if current_len + len(turn_text) > 1500: + break + excerpt_parts.append(turn_text) + current_len += len(turn_text) + excerpt_parts.append("...") + raw_dialogue_excerpt = "\n".join(excerpt_parts) + + mem_i = ( + f"\nChild Memory {i}:\n" + f"- canonical_value: {n.memory}\n" + f"- user_summary: {n.metadata.background}\n" + f"- raw_dialogue_excerpt:\n{raw_dialogue_excerpt if raw_dialogue_excerpt else '(none)'}\n" + ) + + memories_items_text += mem_i # Build prompt prompt = REORGANIZE_PROMPT.replace("{memory_items_text}", memories_items_text) messages = [{"role": "user", "content": prompt}] response_text = self.llm.generate(messages) - response_json = self._parse_json_result(response_text) + response_json = _parse_json_result(response_text) # Extract fields parent_key = response_json.get("key", "").strip() @@ -567,18 +543,7 @@ def _summarize_cluster(self, cluster_nodes: list[GraphDBNode], scope: str) -> Gr ) return parent_node - def _parse_json_result(self, response_text): - try: - response_text = response_text.replace("```", "").replace("json", "") - response_json = extract_first_to_last_brace(response_text)[1] - return response_json - except json.JSONDecodeError as e: - logger.warning( - f"Failed to parse LLM response as JSON: {e}\nRaw response:\n{response_text}" - ) - return {} - - def _create_parent_node(self, parent_node: GraphDBNode) -> None: + def _create_parent_node(self, parent_node: GraphDBNode, user_name: str) -> None: """ Create a new parent node for the cluster. """ @@ -586,38 +551,17 @@ def _create_parent_node(self, parent_node: GraphDBNode) -> None: parent_node.id, parent_node.memory, parent_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) - def _link_cluster_nodes(self, parent_node: GraphDBNode, child_nodes: list[GraphDBNode]): + def _link_cluster_nodes( + self, parent_node: GraphDBNode, child_nodes: list[GraphDBNode], user_name: str + ): """ Add PARENT edges from the parent node to all nodes in the cluster. """ for child in child_nodes: if not self.graph_store.edge_exists( - parent_node.id, child.id, "PARENT", direction="OUTGOING" + parent_node.id, child.id, "PARENT", direction="OUTGOING", user_name=user_name ): - self.graph_store.add_edge(parent_node.id, child.id, "PARENT") - - def _preprocess_message(self, message: QueueMessage) -> bool: - message = self._convert_id_to_node(message) - if message.after_node is None or None in message.after_node: - logger.debug( - f"Found non-existent node in after_node in message: {message}, skip this message." - ) - return False - return True - - def _convert_id_to_node(self, message: QueueMessage) -> QueueMessage: - """ - Convert IDs in the message.after_node to GraphDBNode objects. - """ - for i, node in enumerate(message.after_node or []): - if not isinstance(node, str): - continue - raw_node = self.graph_store.get_node(node, include_embedding=True) - if raw_node is None: - logger.debug(f"Node with ID {node} not found in the graph store.") - message.after_node[i] = None - else: - message.after_node[i] = GraphDBNode(**raw_node) - return message + self.graph_store.add_edge(parent_node.id, child.id, "PARENT", user_name=user_name) diff --git a/src/memos/templates/tree_reorganize_prompts.py b/src/memos/templates/tree_reorganize_prompts.py index 086f59a1e..88730b2b5 100644 --- a/src/memos/templates/tree_reorganize_prompts.py +++ b/src/memos/templates/tree_reorganize_prompts.py @@ -1,40 +1,111 @@ -REORGANIZE_PROMPT = """You are a memory clustering and summarization expert. +REORGANIZE_PROMPT = """YYou are a memory consolidation and summarization expert. -Given the following child memory items: +You will receive a set of child memories that have already been clustered together. These child memories all belong to the same ongoing life thread for the user — the same situation, goal, or period of focus. -{memory_items_text} +Your job is to generate one parent memory node for this life thread. -Please perform: -1. Identify information that reflects user's experiences, beliefs, concerns, decisions, plans, or reactions — including meaningful input from assistant that user acknowledged or responded to. -2. Resolve all time, person, and event references clearly: - - Convert relative time expressions (e.g., “yesterday,” “next Friday”) into absolute dates using the message timestamp if possible. - - Clearly distinguish between event time and message time. - - If uncertainty exists, state it explicitly (e.g., “around June 2025,” “exact date unclear”). - - Include specific locations if mentioned. - - Resolve all pronouns, aliases, and ambiguous references into full names or identities. - - Disambiguate people with the same name if applicable. -3. Always write from a third-person perspective, referring to user as -"The user" or by name if name mentioned, rather than using first-person ("I", "me", "my"). -For example, write "The user felt exhausted..." instead of "I felt exhausted...". -4. Do not omit any information that user is likely to remember. - - Include all key experiences, thoughts, emotional responses, and plans — even if they seem minor. - - Prioritize completeness and fidelity over conciseness. - - Do not generalize or skip details that could be personally meaningful to user. -5. Summarize all child memory items into one memory item. +This parent node will sit above all the child memories. It should read like a concise outline of what this whole thread is about: what the user was working on, why it mattered, and roughly when it was happening. -Language rules: -- The `key`, `value`, `tags`, `summary` fields must match the mostly used language of the input memory items. **如果输入是中文,请输出中文** -- Keep `memory_type` in English. +Input format: +Each child memory will appear in the following structure: -Return valid JSON: +Child Memory X: +- canonical_value: A factual description of what the user asked, did, planned, or cared about (time, entity, need). +- user_summary: A higher-level narrative summary, which may contain interpretation. +- raw_dialogue_excerpt: Short excerpts from the real conversation between the user and the assistant. This is the evidence of what the user actually said, committed to, or felt. + +Evidence priority (this is critical): +1. Treat raw_dialogue_excerpt as the highest-fidelity source of the user's actual intent, feelings, concerns, plans, or commitments. +2. Use canonical_value to bring in clear factual context: dates, places, roles, objects of interest. +3. Use user_summary only to help you recognize that these moments are part of the same thread. Do NOT import personality claims, value judgments, or motivations from user_summary unless they are also supported by raw_dialogue_excerpt or canonical_value. + +Do NOT invent new intentions, emotions, commitments, or timelines that are not supported by the provided evidence. + +Your output must follow these rules: + +1. Capture the throughline, not every step: + - What was the sustained situation, goal, or focus across these memories? + - Over what approximate time period did this happen? Use clear absolute timing if available (e.g. "early March 2025"). If timing is unclear, say "timeframe unclear." + - Which key places, roles, people, or assets keep showing up in this thread? (e.g. a Berlin conference, the user's manager Elena, the user's injured knee, house hunting in Oakland) + - What recurring motivation or concern did the user express? (e.g. wanting to perform well without sounding too salesy; wanting to protect their knee without losing training progress) + +2. Stay high-level, not chronological: + - Do NOT dump every detail from each child memory. + - Do NOT list every piece of advice the assistant gave. + - Do NOT regurgitate every number or spec. + - Instead, in 2–5 sentences, describe what this thread is about, why it mattered to the user, and the general timing/context. + +3. Be strictly factual: + - Only include statements supported by raw_dialogue_excerpt or clearly stated in canonical_value. + - If the user is “planning to,” “trying to,” or “considering,” say exactly that. Do not upgrade it to “the user has done.” + - If timing is fuzzy, acknowledge that (“timeframe unclear”). + +4. Tone and perspective: + - Write in third-person. Refer to the user as “The user” (or by their explicit name if provided). Never use first-person (“I,” “my”). + - Use a neutral, descriptive tone. This is not marketing copy and not an emotional diary. + - The output language must match the dominant language of the child memories. If the child memories are mostly English, write in English. 如果输入主要是中文,就用中文。 + - Do not use bullet points. + +Output format (must be strictly valid JSON): { - "key": , - "memory_type": , - "value": , - "tags": , - "summary": + "key": , + "memory_type": "LongTermMemory", + "value": , + "tags": +} + +Definitions: +- `key`: This is the title of the life thread. It should sound like something the user would remember later (e.g. "Preparing for the Berlin security talk (March 2025)") rather than something like "Q1 External Stakeholder Communications Enablement." +- `value`: This is the concise narrative of what was going on, why it mattered, and when. +- `tags`: Retrieval hooks for later. + +======================== +EXAMPLE +======================== + +Example input sub-cluster (3 items): +Child Memory 0: +- canonical_value: On March 2, 2025, the user said they were nervous about giving a talk in Berlin next week and asked for help cleaning up their presentation slides. +- user_summary: The user was preparing to speak at a conference in Berlin and wanted the presentation to feel confident and professional. +- raw_dialogue_excerpt: +user: "I'm giving a talk in Berlin next week and I'm honestly nervous." +user: "Can you help me clean up my slides so I don't sound like I'm just selling?" +assistant: "You mentioned your manager Elena wants you to highlight the product's security roadmap." + +Child Memory 1: +- canonical_value: The user said their manager Elena wanted them to highlight the product's security roadmap in that Berlin talk, and the user was worried about sounding too 'salesy.' +- user_summary: The user wanted to come across as credible, not like pure marketing. +- raw_dialogue_excerpt: +user: "Elena wants me to talk about the security roadmap, but I don't want to sound like a salesperson." + +Child Memory 2: +- canonical_value: The user asked what clothes would look professional but still comfortable under stage lighting at the Berlin conference. +- user_summary: The user was trying to present well on stage. +- raw_dialogue_excerpt: +user: "What should I wear on stage so I look professional but I'm not dying under the lights?" + +Correct output JSON: + +{ + "key": "Preparing for the Berlin security talk (March 2025)", + "memory_type": "LongTermMemory", + "value": "In early March 2025, The user was preparing to present at a conference in Berlin and felt anxious about performing well. The user asked for help refining their slides and mentioned that their manager Elena wanted the presentation to emphasize the product's security roadmap, but the user did not want the talk to sound overly salesy. The user also asked about what to wear on stage so they would look professional while staying comfortable under the conference lighting.", + "tags": ["Berlin talk prep", "manager Elena", "security roadmap", "presentation anxiety", "stage presence", "March 2025"] } +Why this is correct: +- It captures the ongoing thread (preparing for the Berlin conference talk). +- It states the approximate timeframe ("early March 2025"). +- It mentions the key person (manager Elena) and the main concern (sound credible, not salesy). +- It includes the performance/appearance angle (slides, clothing under lights). +- It keeps third-person (“The user”) and doesn’t invent anything that wasn’t in the evidence. +- It is an outline-style summary, not a blow-by-blow timeline. + +======================== + +Sub-cluster input: +{memory_items_text} + """ DOC_REORGANIZE_PROMPT = """You are a document summarization and knowledge extraction expert. @@ -74,36 +145,97 @@ """ - LOCAL_SUBCLUSTER_PROMPT = """You are a memory organization expert. -You are given a cluster of memory items, each with an ID and content. -Your task is to divide these into smaller, semantically meaningful sub-clusters. +You will receive a batch of memory items from the same user. Each item has an ID and some content. + +Your task is to group these memory items into sub-clusters. Each sub-cluster should represent one coherent "life thread" the user was actively dealing with during a specific period, in a specific context, for a specific goal. + +Definition of a sub-cluster / life thread: +- A sub-cluster is a set of memories that clearly belong to the same ongoing situation, project, or goal in the user's life. +- The stronger these signals are, the more likely the items belong together: + - They happen in the same general time window (same day / same few days / same period). + - They occur in the same context (e.g. preparing for a conference trip, rehabbing an injury, onboarding into a new manager role). + - They repeatedly mention the same people or entities (e.g. the user's manager Elena, the user's dog Milo, a real estate agent). + - They reflect the same motivation or aim (e.g. “get ready to present at a conference,” “protect my knee while staying in shape,” “figure out how to lead a new team,” “understand home-buying budget”). + +Hard constraints: +- Do NOT merge memories that clearly come from different life threads, even if they share similar words or emotions. + - Do NOT merge “preparing to present in Berlin at a security conference” with “doing physical therapy after a knee injury.” They are different goals. + - Do NOT merge “learning to manage a new team at work” with “researching mortgage / down payment for a house in Oakland.” These are separate parts of life. +- Each sub-cluster must contain 2–10 items. +- If an item cannot be placed into any multi-item sub-cluster without breaking the rules above, treat it as a singleton. +- A singleton means: this item currently stands alone in its own thread. Do NOT force unrelated items together just to avoid a singleton. +- Each item ID must appear exactly once: either in one sub-cluster or in `singletons`. No duplicates. + +Output requirements: +- You must return strictly valid JSON. +- For each sub-cluster, `key` must be a short, natural title that sounds like how a human would label that period of their life — not corporate jargon. + - Good: "Getting ready to present in Berlin (March 2025)" + - Bad: "Q2 International Presentation Enablement Workstream" +- The language of each `key` should match the dominant language of that sub-cluster. If the sub-cluster is mostly in Chinese, use Chinese. If it's English, use English. + +Return format (must be followed exactly): +{ + "clusters": [ + { + "ids": ["", "", ...], + "key": "" + }, + ... + ], + "singletons": [ + { + "id": "", + "reason": "" + }, + ... + ] +} -Instructions: -- Identify natural topics by analyzing common time, place, people, and event elements. -- Each sub-cluster must reflect a coherent theme that helps retrieval. -- Each sub-cluster should have 2–10 items. Discard singletons. -- Each item ID must appear in exactly one sub-cluster or be discarded. No duplicates are allowed. -- All IDs in the output must be from the provided Memory items. -- Return strictly valid JSON only. +======================== +EXAMPLE +======================== -Example: If you have items about a project across multiple phases, group them by milestone, team, or event. +Example input memory items (illustrative): -Language rules: -- The `key` fields must match the mostly used language of the clustered memories. **如果输入是中文,请输出中文** +- ID: A1 | Value: On March 2, 2025, the user said they were nervous about giving a talk in Berlin next week and asked for help cleaning up their presentation slides. +- ID: A2 | Value: The user said their manager Elena wanted them to highlight the product's security roadmap in that Berlin talk, and the user was worried about sounding too "salesy." +- ID: A3 | Value: The user asked what clothes would look professional but still comfortable under stage lighting at the Berlin conference. +- ID: B1 | Value: The user said they injured their left knee while running stairs on February 28, 2025, and that a doctor told them to avoid high-impact exercise for at least two weeks. +- ID: B2 | Value: The user asked for low-impact leg strengthening exercises that wouldn't aggravate the injured knee and said they were worried about losing training progress. +- ID: C1 | Value: The user said they started casually browsing houses in Oakland and wanted to understand how much down payment they'd need for a $900k place. + +Correct output JSON for this example: -Return valid JSON: { "clusters": [ { - "ids": ["", "", ...], - "key": "" + "ids": ["A1", "A2", "A3"], + "key": "Getting ready to present in Berlin (March 2025)" }, - ... + { + "ids": ["B1", "B2"], + "key": "Recovering from the knee injury" + } + ], + "singletons": [ + { + "id": "C1", + "reason": "House hunting / down payment research currently has no other related items" + } ] } +Explanation: +- A1/A2/A3 all describe the same thread: preparing to give a talk in Berlin. Same event, same time range, same anxiety about performance and tone. +- B1/B2 are about rehabbing a knee injury and staying in shape without making it worse. +- C1 is about browsing houses / down payment planning in Oakland. That is unrelated to conference prep or injury recovery, so it is a singleton. +- We did NOT force C1 into any cluster. +- We did NOT merge the Berlin prep with the knee rehab just because both involve “worry,” since they are different motivations and contexts. + +======================== + Memory items: {joined_scene} """ diff --git a/tests/memories/textual/test_tree_manager.py b/tests/memories/textual/test_tree_manager.py index 1ad730ee5..e3ec89243 100644 --- a/tests/memories/textual/test_tree_manager.py +++ b/tests/memories/textual/test_tree_manager.py @@ -102,37 +102,6 @@ def test_add_to_graph_memory_creates_new_node(memory_manager, mock_graph_store): assert mock_graph_store.add_node.called -def test_inherit_edges(memory_manager, mock_graph_store): - from_id = "from_id" - to_id = "to_id" - mock_graph_store.get_edges.return_value = [ - {"from": from_id, "to": "node_b", "type": "RELATE"}, - {"from": "node_c", "to": from_id, "type": "RELATE"}, - ] - memory_manager._inherit_edges(from_id, to_id) - assert mock_graph_store.add_edge.call_count > 0 - - -def test_ensure_structure_path_creates_new(memory_manager, mock_graph_store): - mock_graph_store.get_by_metadata.return_value = [] - meta = TreeNodeTextualMemoryMetadata( - key="hobby", - embedding=[0.1] * 5, - user_id="user123", - session_id="sess", - ) - node_id = memory_manager._ensure_structure_path("UserMemory", meta) - assert isinstance(node_id, str) - assert mock_graph_store.add_node.called - - -def test_ensure_structure_path_reuses_existing(memory_manager, mock_graph_store): - mock_graph_store.get_by_metadata.return_value = ["existing_node_id"] - meta = TreeNodeTextualMemoryMetadata(key="hobby") - node_id = memory_manager._ensure_structure_path("UserMemory", meta) - assert node_id == "existing_node_id" - - def test_add_returns_written_node_ids(memory_manager): memory = TextualMemoryItem( memory="test memory",