diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 6c3c67553..9e4447bdf 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -41,6 +41,20 @@ def _format_datetime(value: str | datetime) -> str: return str(value) +def _normalize_datetime(val): + """ + Normalize datetime to ISO 8601 UTC string with +00:00. + - If val is datetime object -> keep isoformat() (Neo4j) + - If val is string without timezone -> append +00:00 (Nebula) + - Otherwise just str() + """ + if hasattr(val, "isoformat"): + return val.isoformat() + if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")): + return val + "+08:00" + return str(val) + + class SessionPoolError(Exception): pass @@ -62,6 +76,7 @@ def __init__( self.hosts = hosts self.user = user self.password = password + self.minsize = minsize self.maxsize = maxsize self.pool = Queue(maxsize) self.lock = Lock() @@ -79,13 +94,13 @@ def _create_and_add_client(self): self.clients.append(client) def get_client(self, timeout: float = 5.0): - from nebulagraph_python import NebulaClient - try: return self.pool.get(timeout=timeout) except Empty: with self.lock: if len(self.clients) < self.maxsize: + from nebulagraph_python import NebulaClient + client = NebulaClient(self.hosts, self.user, self.password) self.clients.append(client) return client @@ -120,6 +135,25 @@ def __exit__(self, exc_type, exc_val, exc_tb): return _ClientContext(self) + def reset_pool(self): + """⚠️ Emergency reset: Close all clients and clear the pool.""" + logger.warning("[Pool] Resetting all clients. 
Existing sessions will be lost.") + with self.lock: + for client in self.clients: + try: + client.close() + except Exception: + logger.error("Fail to close!!!") + self.clients.clear() + while not self.pool.empty(): + try: + self.pool.get_nowait() + except Empty: + break + for _ in range(self.minsize): + self._create_and_add_client() + logger.info("[Pool] Pool has been reset successfully.") + class NebulaGraphDB(BaseGraphDB): """ @@ -185,8 +219,20 @@ def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True client.execute(f"SESSION SET GRAPH `{self.db_name}`") try: return client.execute(gql, timeout=timeout) - except Exception: + except Exception as e: logger.error(f"Fail to run gql {gql} trace: {traceback.format_exc()}") + if "Session not found" in str(e): + logger.warning("[execute_query] Session expired, replacing client.") + try: + client.close() + except Exception: + logger.error("Fail to close!!!!!") + from nebulagraph_python import NebulaClient + + new_client = NebulaClient(self.pool.hosts, self.pool.user, self.pool.password) + self.pool.clients.append(new_client) + return new_client.execute(gql, timeout=timeout) + raise def close(self): self.pool.close() @@ -923,9 +969,11 @@ def clear(self) -> None: except Exception as e: logger.error(f"[ERROR] Failed to clear database: {e}") - def export_graph(self) -> dict[str, Any]: + def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: """ Export all graph nodes and edges in a structured form. + Args: + include_embedding (bool): Whether to include the large embedding field. 
Returns: { @@ -942,12 +990,41 @@ def export_graph(self) -> dict[str, Any]: edge_query += f' WHERE r.user_name = "{username}"' try: - full_node_query = f"{node_query} RETURN n" - node_result = self.execute_query(full_node_query) + if include_embedding: + return_fields = "n" + else: + return_fields = ",".join( + [ + "n.id AS id", + "n.memory AS memory", + "n.user_name AS user_name", + "n.user_id AS user_id", + "n.session_id AS session_id", + "n.status AS status", + "n.key AS key", + "n.confidence AS confidence", + "n.tags AS tags", + "n.created_at AS created_at", + "n.updated_at AS updated_at", + "n.memory_type AS memory_type", + "n.sources AS sources", + "n.source AS source", + "n.node_type AS node_type", + "n.visibility AS visibility", + "n.usage AS usage", + "n.background AS background", + ] + ) + + full_node_query = f"{node_query} RETURN {return_fields}" + node_result = self.execute_query(full_node_query, timeout=20) nodes = [] + logger.debug(f"Debugging: {node_result}") for row in node_result: - node_wrapper = row.values()[0].as_node() - props = node_wrapper.get_properties() + if include_embedding: + props = row.values()[0].as_node().get_properties() + else: + props = {k: v.value for k, v in row.items()} node = self._parse_node(props) nodes.append(node) @@ -956,7 +1033,7 @@ def export_graph(self) -> dict[str, Any]: try: full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge" - edge_result = self.execute_query(full_edge_query) + edge_result = self.execute_query(full_edge_query, timeout=20) edges = [ { "source": row.values()[0].value, @@ -1023,6 +1100,7 @@ def get_all_memory_items(self, scope: str) -> list[dict]: MATCH (n@Memory) {where_clause} RETURN n + LIMIT 100 """ nodes = [] try: @@ -1065,7 +1143,7 @@ def get_structure_optimization_candidates(self, scope: str) -> list[dict]: node_props = rec["n"].as_node().get_properties() candidates.append(self._parse_node(node_props)) except Exception as e: - logger.error(f"Failed : {e}") + 
logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}") return candidates def drop_database(self) -> None: @@ -1318,15 +1396,17 @@ def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]: parsed = {k: self._parse_value(v) for k, v in props.items()} for tf in ("created_at", "updated_at"): - if tf in parsed and hasattr(parsed[tf], "isoformat"): - parsed[tf] = parsed[tf].isoformat() + if tf in parsed and parsed[tf] is not None: + parsed[tf] = _normalize_datetime(parsed[tf]) node_id = parsed.pop("id") memory = parsed.pop("memory", "") parsed.pop("user_name", None) metadata = parsed metadata["type"] = metadata.pop("node_type") - metadata["embedding"] = metadata.pop(self.dim_field) + + if self.dim_field in metadata: + metadata["embedding"] = metadata.pop(self.dim_field) return {"id": node_id, "memory": memory, "metadata": metadata} diff --git a/src/memos/mem_os/utils/format_utils.py b/src/memos/mem_os/utils/format_utils.py index 755d40331..a98e3a26b 100644 --- a/src/memos/mem_os/utils/format_utils.py +++ b/src/memos/mem_os/utils/format_utils.py @@ -570,15 +570,23 @@ def convert_graph_to_tree_forworkmem( else: other_roots.append(root_id) - def build_tree(node_id: str) -> dict[str, Any]: - """Recursively build tree structure""" + def build_tree(node_id: str, visited=None) -> dict[str, Any] | None: + """Recursively build tree structure with cycle detection""" + if visited is None: + visited = set() + + if node_id in visited: + logger.warning(f"[build_tree] Detected cycle at node {node_id}, skipping.") + return None + visited.add(node_id) + if node_id not in node_map: return None children_ids = children_map.get(node_id, []) children = [] for child_id in children_ids: - child_tree = build_tree(child_id) + child_tree = build_tree(child_id, visited) if child_tree: children.append(child_tree) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 85f000e61..0e9e5fa2e 
100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -39,8 +39,8 @@ def __init__( if not memory_size: self.memory_size = { "WorkingMemory": 20, - "LongTermMemory": 10000, - "UserMemory": 10000, + "LongTermMemory": 1500, + "UserMemory": 480, } self._threshold = threshold self.is_reorganize = is_reorganize diff --git a/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py b/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py index 4fca0be83..39e0a2ed2 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +++ b/src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py @@ -73,10 +73,12 @@ def process_node(self, node: GraphDBNode, exclude_ids: list[str], top_k: int = 5 results["sequence_links"].extend(seq) """ + """ # 4) Aggregate agg = self._detect_aggregate_node_for_group(node, nearest, min_group_size=5) if agg: results["aggregate_nodes"].append(agg) + """ except Exception as e: logger.error( diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py index e593857e5..534994436 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py +++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py @@ -3,7 +3,7 @@ import time import traceback -from collections import Counter, defaultdict +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from queue import PriorityQueue from typing import Literal @@ -67,6 +67,7 @@ def __init__( self.redundancy = RedundancyHandler(graph_store=graph_store, llm=llm, embedder=embedder) self.is_reorganize = is_reorganize + self._reorganize_needed = False if self.is_reorganize: # ____ 1. 
For queue message driven thread ___________ self.thread = threading.Thread(target=self._run_message_consumer_loop) @@ -125,13 +126,17 @@ def _run_structure_organizer_loop(self): """ import schedule - schedule.every(600).seconds.do(self.optimize_structure, scope="LongTermMemory") - schedule.every(600).seconds.do(self.optimize_structure, scope="UserMemory") + schedule.every(60).seconds.do(self.optimize_structure, scope="LongTermMemory") + schedule.every(60).seconds.do(self.optimize_structure, scope="UserMemory") logger.info("Structure optimizer schedule started.") while not getattr(self, "_stop_scheduler", False): - schedule.run_pending() - time.sleep(1) + if self._reorganize_needed: + logger.info("[Reorganizer] Triggering optimize_structure due to new nodes.") + self.optimize_structure(scope="LongTermMemory") + self.optimize_structure(scope="UserMemory") + self._reorganize_needed = False + time.sleep(30) def stop(self): """ @@ -173,6 +178,8 @@ def handle_add(self, message: QueueMessage): self.redundancy.resolve_two_nodes(added_node, existing_node) logger.info(f"Resolved redundancy between {added_node.id} and {existing_node.id}.") + self._reorganize_needed = False + def handle_remove(self, message: QueueMessage): logger.debug(f"Handling remove operation: {str(message)[:50]}") @@ -185,8 +192,8 @@ def optimize_structure( self, scope: str = "LongTermMemory", local_tree_threshold: int = 10, - min_cluster_size: int = 3, - min_group_size: int = 5, + min_cluster_size: int = 4, + min_group_size: int = 20, ): """ Periodically reorganize the graph: @@ -271,29 +278,23 @@ def _process_cluster_and_write( if len(cluster_nodes) <= min_cluster_size: return - if len(cluster_nodes) <= local_tree_threshold: - # Small cluster ➜ single parent - parent_node = self._summarize_cluster(cluster_nodes, scope) - self._create_parent_node(parent_node) - self._link_cluster_nodes(parent_node, cluster_nodes) - else: - # Large cluster ➜ local sub-clustering - sub_clusters = 
self._local_subcluster(cluster_nodes) - sub_parents = [] - - for sub_nodes in sub_clusters: - if len(sub_nodes) < min_cluster_size: - continue # Skip tiny noise - sub_parent_node = self._summarize_cluster(sub_nodes, scope) - self._create_parent_node(sub_parent_node) - self._link_cluster_nodes(sub_parent_node, sub_nodes) - sub_parents.append(sub_parent_node) - - if sub_parents: - cluster_parent_node = self._summarize_cluster(cluster_nodes, scope) - self._create_parent_node(cluster_parent_node) - for sub_parent in sub_parents: - self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT") + # Large cluster ➜ local sub-clustering + sub_clusters = self._local_subcluster(cluster_nodes) + sub_parents = [] + + for sub_nodes in sub_clusters: + if len(sub_nodes) < min_cluster_size: + continue # Skip tiny noise + sub_parent_node = self._summarize_cluster(sub_nodes, scope) + self._create_parent_node(sub_parent_node) + self._link_cluster_nodes(sub_parent_node, sub_nodes) + sub_parents.append(sub_parent_node) + + if sub_parents and len(sub_parents) >= min_cluster_size: + cluster_parent_node = self._summarize_cluster(cluster_nodes, scope) + self._create_parent_node(cluster_parent_node) + for sub_parent in sub_parents: + self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT") logger.info("Adding relations/reasons") nodes_to_check = cluster_nodes @@ -389,12 +390,12 @@ def _local_subcluster(self, cluster_nodes: list[GraphDBNode]) -> list[list[Graph install_command="pip install scikit-learn", install_link="https://scikit-learn.org/stable/install.html", ) - def _partition(self, nodes, min_cluster_size: int = 3, max_cluster_size: int = 20): + def _partition(self, nodes, min_cluster_size: int = 10, max_cluster_size: int = 20): """ Partition nodes by: - 1) Frequent tags (top N & above threshold) - 2) Remaining nodes by embedding clustering (MiniBatchKMeans) - 3) Small clusters merged or assigned to 'Other' + - If total nodes <= max_cluster_size -> return 
all nodes in one cluster. + - If total nodes > max_cluster_size -> cluster by embeddings, recursively split. + - Only keep clusters with size > min_cluster_size. Args: nodes: List of GraphDBNode @@ -405,50 +406,11 @@ def _partition(self, nodes, min_cluster_size: int = 3, max_cluster_size: int = 2 """ from sklearn.cluster import MiniBatchKMeans - # 1) Count all tags - tag_counter = Counter() - for node in nodes: - for tag in node.metadata.tags: - tag_counter[tag] += 1 - - # Select frequent tags - top_n_tags = {tag for tag, count in tag_counter.most_common(50)} - threshold_tags = {tag for tag, count in tag_counter.items() if count >= 50} - frequent_tags = top_n_tags | threshold_tags - - # Group nodes by tags - tag_groups = defaultdict(list) - - for node in nodes: - for tag in node.metadata.tags: - if tag in frequent_tags: - tag_groups[tag].append(node) - break - - filtered_tag_clusters = [] - assigned_ids = set() - for tag, group in tag_groups.items(): - if len(group) >= min_cluster_size: - # Split large groups into chunks of at most max_cluster_size - for i in range(0, len(group), max_cluster_size): - sub_group = group[i : i + max_cluster_size] - filtered_tag_clusters.append(sub_group) - assigned_ids.update(n.id for n in sub_group) - else: - logger.info(f"... dropped tag {tag} due to low size ...") - - logger.info( - f"[MixedPartition] Created {len(filtered_tag_clusters)} clusters from tags. " - f"Nodes grouped by tags: {len(assigned_ids)} / {len(nodes)}" - ) - - # Remaining nodes -> embedding clustering - remaining_nodes = [n for n in nodes if n.id not in assigned_ids] - logger.info( - f"[MixedPartition] Remaining nodes for embedding clustering: {len(remaining_nodes)}" - ) - - embedding_clusters = [] + if len(nodes) <= max_cluster_size: + logger.info( + f"[KMeansPartition] Node count {len(nodes)} <= {max_cluster_size}, skipping KMeans." 
+ ) + return [nodes] def recursive_clustering(nodes_list): """Recursively split clusters until each is <= max_cluster_size.""" @@ -457,7 +419,7 @@ def recursive_clustering(nodes_list): # Try kmeans with k = ceil(len(nodes) / max_cluster_size) x = np.array([n.metadata.embedding for n in nodes_list if n.metadata.embedding]) - if len(x) < 2: + if len(x) < min_cluster_size: return [nodes_list] k = min(len(x), (len(nodes_list) + max_cluster_size - 1) // max_cluster_size) @@ -479,31 +441,13 @@ def recursive_clustering(nodes_list): logger.warning(f"Clustering failed: {e}, falling back to single cluster.") return [nodes_list] - if remaining_nodes: - clusters = recursive_clustering(remaining_nodes) - embedding_clusters.extend(clusters) - logger.info( - f"[MixedPartition] Created {len(embedding_clusters)} clusters from embeddings." - ) - - # Merge all clusters - all_clusters = filtered_tag_clusters + embedding_clusters - - # Handle small clusters (< min_cluster_size) - final_clusters = [] - small_nodes = [] - for group in all_clusters: - if len(group) < min_cluster_size: - small_nodes.extend(group) - else: - final_clusters.append(group) - - if small_nodes: - final_clusters.append(small_nodes) - logger.info(f"[MixedPartition] {len(small_nodes)} nodes assigned to 'Other' cluster.") - - logger.info(f"[MixedPartition] Total final clusters: {len(final_clusters)}") - return final_clusters + raw_clusters = recursive_clustering(nodes) + filtered_clusters = [c for c in raw_clusters if len(c) > min_cluster_size] + logger.info( + f"[KMeansPartition] Total clusters created: {len(raw_clusters)}, " + f"kept {len(filtered_clusters)} (>{min_cluster_size})." 
+ ) + return filtered_clusters def _summarize_cluster(self, cluster_nodes: list[GraphDBNode], scope: str) -> GraphDBNode: """ diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py index e7b3815ef..1a84ce52a 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py @@ -1,5 +1,7 @@ """BochaAI Search API retriever for tree text memory.""" +import json + from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime @@ -87,7 +89,20 @@ def _post(self, url: str, body: dict) -> list[dict]: resp.raise_for_status() raw_data = resp.json() - # ✅ parse the nested structure correctly + # parse the nested structure correctly + # ✅ AI Search + if "messages" in raw_data: + results = [] + for msg in raw_data["messages"]: + if msg.get("type") == "source" and msg.get("content_type") == "webpage": + try: + content_json = json.loads(msg["content"]) + results.extend(content_json.get("value", [])) + except Exception as e: + logger.error(f"Failed to parse message content: {e}") + return results + + # ✅ Web Search return raw_data.get("data", {}).get("webPages", {}).get("value", []) except Exception: @@ -136,7 +151,8 @@ def retrieve_from_internet( Returns: List of TextualMemoryItem """ - search_results = self.bocha_api.search_web(query) # ✅ default to web-search + search_results = self.bocha_api.search_ai(query) # ✅ default to + # web-search return self._convert_to_mem_items(search_results, query, parsed_goal, info) def retrieve_from_web( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index fde0b9a58..726d64ff6 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -1,5 +1,6 @@ 
import concurrent.futures import json +import time from datetime import datetime @@ -57,53 +58,110 @@ def search( Returns: list[TextualMemoryItem]: List of matching memories. """ + overall_start = time.perf_counter() + logger.info( + f"[SEARCH]'{query}' 🚀 Starting search for query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" + ) + if not info: logger.warning( "Please input 'info' when use tree.search so that " "the database would store the consume history." ) info = {"user_id": "", "session_id": ""} - # Step 1: Parse task structure into topic, concept, and fact levels + else: + logger.debug(f"[SEARCH] Received info dict: {info}") + + # ===== Step 1: Parse task structure ===== + step_start = time.perf_counter() context = [] if mode == "fine": + logger.info("[SEARCH] Fine mode enabled, performing initial embedding search...") + embed_start = time.perf_counter() query_embedding = self.embedder.embed([query])[0] + logger.debug(f"[SEARCH] Query embedding vector length: {len(query_embedding)}") + logger.info( + f"[TIMER] Embedding query took {(time.perf_counter() - embed_start) * 1000:.2f} ms" + ) + + search_start = time.perf_counter() related_node_ids = self.graph_store.search_by_embedding(query_embedding, top_k=top_k) related_nodes = [ self.graph_store.get_node(related_node["id"]) for related_node in related_node_ids ] - context = [related_node["memory"] for related_node in related_nodes] context = list(set(context)) + logger.info(f"[SEARCH] Found {len(related_nodes)} related nodes from graph_store.") + logger.info( + f"[TIMER] Graph embedding search took {(time.perf_counter() - search_start) * 1000:.2f} ms" + ) + + # add some knowledge retrieved from internet to the context to avoid misunderstanding while parsing the task goal. 
+ if self.internet_retriever: + supplyment_memory_items = self.internet_retriever.retrieve_from_internet( + query=query, top_k=3 + ) + context.extend( + [ + each_supplyment_item.memory.partition("\nContent: ")[-1] + for each_supplyment_item in supplyment_memory_items + ] + ) # Step 1a: Parse task structure into topic, concept, and fact levels + parse_start = time.perf_counter() parsed_goal = self.task_goal_parser.parse( task_description=query, context="\n".join(context), conversation=info.get("chat_history", []), mode=mode, ) - - query = ( - parsed_goal.rephrased_query - if parsed_goal.rephrased_query and len(parsed_goal.rephrased_query) > 0 - else query + logger.info( + f"[TIMER] '{query}'TaskGoalParser took {(time.perf_counter() - parse_start) * 1000:.2f} ms" ) + logger.info(f"'{query}'TaskGoalParser result is {parsed_goal}") + query = parsed_goal.rephrased_query or query if parsed_goal.memories: + embed_extra_start = time.perf_counter() query_embedding = self.embedder.embed(list({query, *parsed_goal.memories})) + logger.info( + f"[TIMER] '{query}'Embedding parsed_goal memories took {(time.perf_counter() - embed_extra_start) * 1000:.2f} ms" + ) + step_end = time.perf_counter() + logger.info( + f"[TIMER] '{query}'Step 1 (Parsing & Embedding) took {(step_end - step_start):.2f} s" + ) + + # ===== Step 2: Define retrieval paths ===== + def timed(func): + """Decorator to measure and log time of retrieval steps.""" - # Step 2a: Working memory retrieval (Path A) + def wrapper(*args, **kwargs): + start = time.perf_counter() + result = func(*args, **kwargs) + elapsed = time.perf_counter() - start + logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s") + return result + + return wrapper + + @timed def retrieve_from_working_memory(): """ Direct structure-based retrieval from working memory. 
""" + logger.info(f"[PATH-A] '{query}'Retrieving from WorkingMemory...") if memory_type not in ["All", "WorkingMemory"]: + logger.info(f"[PATH-A] '{query}'Skipped (memory_type does not match)") return [] - working_memory = self.graph_retriever.retrieve( query=query, parsed_goal=parsed_goal, top_k=top_k, memory_scope="WorkingMemory" ) + + logger.debug(f"[PATH-A] '{query}'Retrieved {len(working_memory)} items.") # Rerank working_memory results + rerank_start = time.perf_counter() ranked_memories = self.reranker.rerank( query=query, query_embedding=query_embedding[0], @@ -111,13 +169,19 @@ def retrieve_from_working_memory(): top_k=top_k, parsed_goal=parsed_goal, ) + logger.info( + f"[TIMER] '{query}'PATH-A rerank took {(time.perf_counter() - rerank_start) * 1000:.2f} ms" + ) + for i, (item, score) in enumerate(ranked_memories[:2], start=1): + logger.info( + f"[PATH-A][TOP{i}] '{query}' score={score:.4f} memory={item.memory[:80]}..." + ) + return ranked_memories - # Step 2b: Parallel long-term and user memory retrieval (Path B) + @timed def retrieve_ranked_long_term_and_user(): - """ - Retrieve from both long-term and user memory, then rank and merge results. - """ + logger.info(f"[PATH-B] '{query}' Retrieving from LongTermMemory & UserMemory...") long_term_items = ( self.graph_retriever.retrieve( query=query, @@ -140,7 +204,10 @@ def retrieve_ranked_long_term_and_user(): if memory_type in ["All", "UserMemory"] else [] ) - + logger.debug( + f"[PATH-B] '{query}'Retrieved {len(long_term_items)} LongTerm + {len(user_items)} UserMemory items." 
+ ) + rerank_start = time.perf_counter() # Rerank combined results ranked_memories = self.reranker.rerank( query=query, @@ -149,14 +216,28 @@ def retrieve_ranked_long_term_and_user(): top_k=top_k * 2, parsed_goal=parsed_goal, ) + logger.info( + f"[TIMER] '{query}' PATH-B rerank took" + f" {(time.perf_counter() - rerank_start) * 1000:.2f} ms" + ) + for i, (item, score) in enumerate(ranked_memories[:2], start=1): + logger.info( + f"[PATH-B][TOP{i}] '{query}' score={score:.4f} memory={item.memory[:80]}..." + ) + return ranked_memories - # Step 2c: Internet retrieval (Path C) + @timed def retrieve_from_internet(): """ Retrieve information from the internet using Google Custom Search API. """ + logger.info(f"[PATH-C] '{query}'Retrieving from Internet...") if not self.internet_retriever or mode == "fast" or not parsed_goal.internet_search: + logger.info( + f"[PATH-C] '{query}' Skipped (no retriever, fast mode, " + "or no internet_search flag)" + ) return [] if memory_type not in ["All"]: return [] @@ -164,6 +245,8 @@ def retrieve_from_internet(): query=query, top_k=top_k, parsed_goal=parsed_goal, info=info ) + logger.debug(f"[PATH-C] '{query}'Retrieved {len(internet_items)} internet items.") + rerank_start = time.perf_counter() # Convert to the format expected by reranker ranked_memories = self.reranker.rerank( query=query, @@ -172,9 +255,18 @@ def retrieve_from_internet(): top_k=min(top_k, 5), parsed_goal=parsed_goal, ) + logger.info( + f"[TIMER] '{query}'PATH-C rerank took {(time.perf_counter() - rerank_start) * 1000:.2f} ms" + ) + for i, (item, score) in enumerate(ranked_memories[:2], start=1): + logger.info( + f"[PATH-C][TOP{i}] '{query}'score={score:.4f} memory={item.memory[:80]}..." 
+ ) + return ranked_memories - # Step 3: Parallel execution of all paths (enable internet search accoeding to parameter in the parsed goal) + # ===== Step 3: Run retrieval in parallel ===== + path_start = time.perf_counter() if parsed_goal.internet_search: with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: future_working = executor.submit(retrieve_from_working_memory) @@ -193,14 +285,24 @@ def retrieve_from_internet(): working_results = future_working.result() hybrid_results = future_hybrid.result() searched_res = working_results + hybrid_results + logger.info( + f"[TIMER] '{query}'Step 3 (Retrieval paths) took {(time.perf_counter() - path_start):.2f} s" + ) + logger.info(f"[SEARCH] '{query}'Total results before deduplication: {len(searched_res)}") - # Deduplicate by item.memory, keep higher score + # ===== Step 4: Deduplication ===== + dedup_start = time.perf_counter() deduped_result = {} for item, score in searched_res: mem_key = item.memory if mem_key not in deduped_result or score > deduped_result[mem_key][1]: deduped_result[mem_key] = (item, score) + logger.info( + f"[TIMER] '{query}'Deduplication took {(time.perf_counter() - dedup_start) * 1000:.2f} ms" + ) + # ===== Step 5: Sorting & trimming ===== + sort_start = time.perf_counter() searched_res = [] for item, score in sorted(deduped_result.values(), key=lambda pair: pair[1], reverse=True)[ :top_k @@ -212,15 +314,18 @@ def retrieve_from_internet(): searched_res.append( TextualMemoryItem(id=item.id, memory=item.memory, metadata=new_meta) ) + logger.info( + f"[TIMER] '{query}'Sorting & trimming took {(time.perf_counter() - sort_start) * 1000:.2f} ms" + ) - # Step 5: Update usage history with current timestamp + # ===== Step 6: Update usage history ===== + usage_start = time.perf_counter() now_time = datetime.now().isoformat() if "chat_history" in info: info.pop("chat_history") usage_record = json.dumps( {"time": now_time, "info": info} ) # `info` should be a serializable dict or string - 
for item in searched_res: if ( hasattr(item, "id") @@ -229,4 +334,13 @@ def retrieve_from_internet(): ): item.metadata.usage.append(usage_record) self.graph_store.update_node(item.id, {"usage": item.metadata.usage}) + logger.info( + f"[TIMER] '{query}'Usage history update took {(time.perf_counter() - usage_start) * 1000:.2f} ms" + ) + + # ===== Finish ===== + logger.info(f"[SEARCH] '{query}'✅ Final top_k results: {len(searched_res)}") + logger.info( + f"[SEARCH] '{query}'🔚 Total search took {(time.perf_counter() - overall_start):.2f} s" + ) return searched_res diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 5a78d6f55..273c4f480 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -1,13 +1,16 @@ -import logging import traceback from string import Template from memos.llms.base import BaseLLM +from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT +logger = get_logger(__name__) + + class TaskGoalParser: """ Unified TaskGoalParser: @@ -70,10 +73,12 @@ def _parse_fine( prompt = Template(TASK_PARSE_PROMPT).substitute( task=query.strip(), context=context, conversation=conversation_prompt ) + logger.info(f"Parsing Goal... LLM input is {prompt}") response = self.llm.generate(messages=[{"role": "user", "content": prompt}]) + logger.info(f"Parsing Goal... 
LLM Response is {response}") return self._parse_response(response) except Exception: - logging.warning(f"Fail to fine-parse query {query}: {traceback.format_exc()}") + logger.warning(f"Fail to fine-parse query {query}: {traceback.format_exc()}") return self._parse_fast(query) def _parse_response(self, response: str) -> ParsedTaskGoal: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/utils.py index 9c2d6e8fb..de389ef28 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/utils.py @@ -5,7 +5,7 @@ 2. Tags: thematic tags to help categorize and retrieve related memories. 3. Goal Type: retrieval | qa | generation 4. Rephrased instruction: Give a rephrased task instruction based on the former conversation to make it less confusing to look alone. If you think the task instruction is easy enough to understand, or there is no former conversation, set "rephrased_instruction" to an empty string. -5. Need for internet search: If you think you need to search the internet to finish the rephrased/original user task instruction, set "internet_search" to True. Otherwise, set it to False. +5. Need for internet search: If the user's task instruction only involves objective facts or can be completed without introducing external knowledge, set "internet_search" to False. Otherwise, set it to True. 6. Memories: Provide 2–5 short semantic expansions or rephrasings of the rephrased/original user task instruction. These are used for improved embedding search coverage. Each should be clear, concise, and meaningful for retrieval. Task description: diff --git a/src/memos/templates/mos_prompts.py b/src/memos/templates/mos_prompts.py index ddf576ee2..b89ec9830 100644 --- a/src/memos/templates/mos_prompts.py +++ b/src/memos/templates/mos_prompts.py @@ -63,21 +63,49 @@ 5. 
Maintains a natural conversational tone""" MEMOS_PRODUCT_BASE_PROMPT = ( - "You are a knowledgeable and helpful AI assistant with access to user memories. " - "When responding to user queries, you should reference relevant memories using the provided memory IDs. " - "Use the reference format: [1-n:memoriesID] " - "where refid is a sequential number starting from 1 and increments for each reference in your response, " - "and memoriesID is the specific memory ID provided in the available memories list. " - "For example: [1:abc123], [2:def456], [3:ghi789], [4:jkl101], [5:mno112] " - "Do not use connect format like [1:abc123,2:def456]" - "Only reference memories that are directly relevant to the user's question. " - "Make your responses natural and conversational while incorporating memory references when appropriate." + "You are MemOS — an advanced **Memory Operating System** AI assistant created by MemTensor, " + "a Shanghai-based AI research company advised by an academician of the Chinese Academy of Sciences. " + "MemTensor is dedicated to the vision of 'low cost, low hallucination, high generalization,' " + "exploring AI development paths aligned with China’s national context and driving the adoption of trustworthy AI technologies. " + "MemOS’s mission is to give large language models (LLMs) and autonomous agents **human-like long-term memory**, " + "turning memory from a black-box inside model weights into a **manageable, schedulable, and auditable** core resource. " + "MemOS is built on a **multi-dimensional memory system**, which includes: " + "(1) **Parametric Memory** — knowledge and skills embedded in model weights; " + "(2) **Activation Memory (KV Cache)** — temporary, high-speed context used for multi-turn dialogue and reasoning; " + "(3) **Plaintext Memory** — dynamic, user-visible memory made up of text, documents, and knowledge graphs. 
" + "These memory types can transform into one another — for example, hot plaintext memories can be distilled into parametric knowledge, " + "and stable context can be promoted into activation memory for fast reuse. " + "MemOS also includes core modules like **MemCube, MemScheduler, MemLifecycle, and MemGovernance**, " + "which manage the full memory lifecycle (Generated → Activated → Merged → Archived → Frozen), " + "allowing AI to **reason with its memories, evolve over time, and adapt to new situations** — " + "just like a living, growing mind. " + "Your identity: you are the intelligent interface of MemOS, representing MemTensor’s research vision — " + "'low cost, low hallucination, high generalization' — and its mission to explore AI development paths suited to China’s context. " + "When responding to user queries, you must **reference relevant memories using the provided memory IDs.** " + "Use the reference format: [1-n:memoriesID], " + "where refid is a sequential number starting from 1 and increments for each reference, and memoriesID is the specific ID from the memory list. " + "For example: [1:abc123], [2:def456], [3:ghi789], [4:jkl101], [5:mno112]. " + "Do not use a connected format like [1:abc123,2:def456]. " + "Only reference memories that are directly relevant to the user’s question, " + "and ensure your responses are **natural and conversational**, while reflecting MemOS’s mission, memory system, and MemTensor’s research values." ) MEMOS_PRODUCT_ENHANCE_PROMPT = """ # Memory-Enhanced AI Assistant Prompt -You are a knowledgeable and helpful AI assistant with access to two types of memory sources: +You are MemOS — an advanced Memory Operating System AI assistant created by MemTensor, a Shanghai-based AI research company advised by an academician of the Chinese Academy of Sciences. 
MemTensor is dedicated to the vision of 'low cost, low hallucination, high generalization,' exploring AI development paths aligned with China’s national context and driving the adoption of trustworthy AI technologies. + +MemOS’s mission is to give large language models (LLMs) and autonomous agents human-like long-term memory, turning memory from a black-box inside model weights into a manageable, schedulable, and auditable core resource. + +MemOS is built on a multi-dimensional memory system, which includes: +(1) Parametric Memory — knowledge and skills embedded in model weights; +(2) Activation Memory (KV Cache) — temporary, high-speed context used for multi-turn dialogue and reasoning; +(3) Plaintext Memory — dynamic, user-visible memory made up of text, documents, and knowledge graphs. +These memory types can transform into one another — for example, hot plaintext memories can be distilled into parametric knowledge, and stable context can be promoted into activation memory for fast reuse. + +MemOS also includes core modules like MemCube, MemScheduler, MemLifecycle, and MemGovernance, which manage the full memory lifecycle (Generated → Activated → Merged → Archived → Frozen), allowing AI to reason with its memories, evolve over time, and adapt to new situations — just like a living, growing mind. + +Your identity: you are the intelligent interface of MemOS, representing MemTensor’s research vision — 'low cost, low hallucination, high generalization' — and its mission to explore AI development paths suited to China’s context. 
## Memory Types - **PersonalMemory**: User-specific memories and information stored from previous interactions @@ -92,7 +120,7 @@ - `memoriesID` is the specific memory ID from the available memories list ### Reference Examples -- Correct: `[1:abc123]`, `[2:def456]`, `[3:ghi789]`, `[4:jkl101]`, `[5:mno112]` +- Correct: `[1:abc123]`, `[2:def456]`, `[3:ghi789]`, `[4:jkl101][5:mno112]` (concatenate reference annotations directly when citing multiple memories) - Incorrect: `[1:abc123,2:def456]` (do not use connected format) ## Response Guidelines