
Commit 8a0e930

Merge branch 'dev' into feat/dedup-search-param
2 parents 787f6f0 + fac1aa7 commit 8a0e930

4 files changed (+193 -82 lines)


src/memos/api/handlers/memory_handler.py

Lines changed: 13 additions & 0 deletions
@@ -23,6 +23,10 @@
     remove_embedding_recursive,
     sort_children_by_memory_type,
 )
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+    cosine_similarity_matrix,
+    find_best_unrelated_subgroup,
+)
 
 
 if TYPE_CHECKING:
@@ -37,6 +41,7 @@ def handle_get_all_memories(
     mem_cube_id: str,
     memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"],
     naive_mem_cube: Any,
+    embedder: Any,
 ) -> MemoryResponse:
     """
     Main handler for getting all memories.
@@ -59,6 +64,14 @@ def handle_get_all_memories(
     # Get all text memories from the graph database
     memories = naive_mem_cube.text_mem.get_all(user_name=mem_cube_id)
 
+    mems = [mem.get("memory", "") for mem in memories.get("nodes", [])]
+    embeddings = embedder.embed(mems)
+    similarity_matrix = cosine_similarity_matrix(embeddings)
+    selected_indices, _ = find_best_unrelated_subgroup(
+        embeddings, similarity_matrix, bar=0.9
+    )
+    memories["nodes"] = [memories["nodes"][i] for i in selected_indices]
+
     # Format and convert to tree structure
     memories_cleaned = remove_embedding_recursive(memories)
     custom_type_ratios = {
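The two helpers imported from retrieve_utils are not shown in this diff. Below is a minimal sketch of what the dedup step plausibly computes, assuming cosine_similarity_matrix returns the pairwise cosine matrix and find_best_unrelated_subgroup greedily keeps indices whose mutual similarity stays below bar; both are assumptions about code that lives elsewhere, and the _sketch names are invented here:

    import numpy as np

    def cosine_similarity_matrix_sketch(embeddings: list[list[float]]) -> np.ndarray:
        """Pairwise cosine similarity: sim[i, j] = cos(embeddings[i], embeddings[j])."""
        e = np.asarray(embeddings, dtype=float)
        # Normalize rows, guarding against zero-length vectors.
        e = e / np.clip(np.linalg.norm(e, axis=1, keepdims=True), 1e-12, None)
        return e @ e.T

    def find_best_unrelated_subgroup_sketch(
        embeddings: list[list[float]], sim: np.ndarray, bar: float = 0.9
    ) -> tuple[list[int], np.ndarray]:
        """Greedily keep each index whose similarity to every kept index is below bar."""
        selected: list[int] = []
        for i in range(len(embeddings)):
            if all(sim[i, j] < bar for j in selected):
                selected.append(i)
        # Return the kept indices plus their similarity submatrix, matching the
        # two-value unpacking in handle_get_all_memories (which uses only the first).
        return selected, sim[np.ix_(selected, selected)]

Read this way, with bar=0.9 the handler drops any node whose memory text embeds as a near-duplicate of a node already kept, then rebuilds memories["nodes"] from the surviving indices. That is the de-duplication the feat/dedup-search-param branch name refers to.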

src/memos/api/product_models.py

Lines changed: 13 additions & 0 deletions
@@ -1177,3 +1177,16 @@ class AllStatusResponse(BaseResponse[AllStatusResponseData]):
     """Response model for full scheduler status operations."""
 
     message: str = "Scheduler status summary retrieved successfully"
+
+
+# ─── Internal API Endpoints Models (for internal use) ───────────────────────────────────────────────────
+
+
+class GetUserNamesByMemoryIdsRequest(BaseRequest):
+    """Request model for getting user names by memory ids."""
+
+    memory_ids: list[str] = Field(..., description="Memory IDs")
+
+
+class GetUserNamesByMemoryIdsResponse(BaseResponse[dict[str, list[str]]]):
+    """Response model for getting user names by memory ids."""

src/memos/api/routers/server_router.py

Lines changed: 30 additions & 0 deletions
@@ -36,13 +36,16 @@
     GetMemoryPlaygroundRequest,
     GetMemoryRequest,
     GetMemoryResponse,
+    GetUserNamesByMemoryIdsRequest,
+    GetUserNamesByMemoryIdsResponse,
     MemoryResponse,
     SearchResponse,
     StatusResponse,
     SuggestionRequest,
     SuggestionResponse,
     TaskQueueResponse,
 )
+from memos.graph_dbs.polardb import PolarDBGraphDB
 from memos.log import get_logger
 from memos.mem_scheduler.base_scheduler import BaseScheduler
 from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
@@ -83,6 +86,8 @@
 naive_mem_cube = components["naive_mem_cube"]
 redis_client = components["redis_client"]
 status_tracker = TaskStatusTracker(redis_client=redis_client)
+embedder = components["embedder"]
+graph_db = components["graph_db"]
 
 
 # =============================================================================
@@ -294,6 +299,7 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
         ),
         memory_type=memory_req.memory_type or "text_mem",
         naive_mem_cube=naive_mem_cube,
+        embedder=embedder,
     )
 
 
@@ -327,3 +333,27 @@ def feedback_memories(feedback_req: APIFeedbackRequest):
     This endpoint uses the class-based FeedbackHandler for better code organization.
     """
     return feedback_handler.handle_feedback_memories(feedback_req)
+
+
+# =============================================================================
+# Other API Endpoints (for internal use)
+# =============================================================================
+
+
+@router.get(
+    "/get_user_names_by_memory_ids",
+    summary="Get user names by memory ids",
+    response_model=GetUserNamesByMemoryIdsResponse,
+)
+def get_user_names_by_memory_ids(memory_ids: GetUserNamesByMemoryIdsRequest):
+    """Get user names by memory ids."""
+    if not isinstance(graph_db, PolarDBGraphDB):
+        raise HTTPException(
+            status_code=400,
+            detail=(
+                "graph_db must be an instance of PolarDBGraphDB to use "
+                "get_user_names_by_memory_ids"
+                f"current graph_db is: {graph_db.__class__.__name__}"
+            ),
+        )
+    return graph_db.get_user_names_by_memory_ids(memory_ids=memory_ids)
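A minimal client sketch for the new internal route, assuming a deployment reachable at http://localhost:8000 (a hypothetical base URL). Because the route is declared with @router.get while its parameter is a Pydantic model, FastAPI reads the payload from the JSON request body, so the client has to attach a body to the GET request:

    import requests  # any HTTP client that can attach a body to a GET works

    resp = requests.get(
        "http://localhost:8000/get_user_names_by_memory_ids",  # hypothetical base URL
        json={"memory_ids": ["mem-001"]},  # body matches GetUserNamesByMemoryIdsRequest
    )
    resp.raise_for_status()
    print(resp.json())  # envelope per GetUserNamesByMemoryIdsResponse

The isinstance guard makes the route fail fast: a deployment whose graph_db is not a PolarDBGraphDB receives a 400 whose detail names the actual class, rather than an error from deep inside the handler.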

src/memos/graph_dbs/polardb.py

Lines changed: 137 additions & 82 deletions
@@ -4869,6 +4869,7 @@ def delete_node_by_prams(
         memory_ids: list[str] | None = None,
         file_ids: list[str] | None = None,
         filter: dict | None = None,
+        batch_size: int = 100,
     ) -> int:
         """
         Delete nodes by memory_ids, file_ids, or filter.
@@ -4898,31 +4899,6 @@ def delete_node_by_prams(
                 f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
             )
 
-        # Build WHERE conditions separately for memory_ids and file_ids
-        where_conditions = []
-
-        # Handle memory_ids: query properties.id
-        if memory_ids and len(memory_ids) > 0:
-            memory_id_conditions = []
-            for node_id in memory_ids:
-                memory_id_conditions.append(
-                    f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
-                )
-            if memory_id_conditions:
-                where_conditions.append(f"({' OR '.join(memory_id_conditions)})")
-
-        # Check if any file_id is in the file_ids array field (OR relationship)
-        if file_ids and len(file_ids) > 0:
-            file_id_conditions = []
-            for file_id in file_ids:
-                # Format: agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '"file_ids"'::agtype]), '"file_id"'::agtype)
-                file_id_conditions.append(
-                    f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
-                )
-            if file_id_conditions:
-                # Use OR to match any file_id in the array
-                where_conditions.append(f"({' OR '.join(file_id_conditions)})")
-
         # Query nodes by filter if provided
         filter_ids = set()
         if filter:
@@ -4943,86 +4919,165 @@ def delete_node_by_prams(
                     "[delete_node_by_prams] Filter parsed to None, skipping filter query"
                 )
 
-        # If filter returned IDs, add condition for them
+        # Combine all IDs that need to be deleted
+        all_memory_ids = set()
+        if memory_ids:
+            all_memory_ids.update(memory_ids)
         if filter_ids:
-            filter_id_conditions = []
-            for node_id in filter_ids:
-                filter_id_conditions.append(
-                    f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
-                )
-            if filter_id_conditions:
-                where_conditions.append(f"({' OR '.join(filter_id_conditions)})")
+            all_memory_ids.update(filter_ids)
 
-        # If no conditions (except user_name), return 0
-        if not where_conditions:
+        # If no conditions to delete, return 0
+        if not all_memory_ids and not file_ids:
             logger.warning(
                 "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
             )
             return 0
 
-        # Build WHERE clause
-        # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
-        data_conditions = " OR ".join([f"({cond})" for cond in where_conditions])
+        conn = None
+        total_deleted_count = 0
+        try:
+            conn = self._get_connection()
+            with conn.cursor() as cursor:
+                # Process memory_ids and filter_ids in batches
+                if all_memory_ids:
+                    memory_ids_list = list(all_memory_ids)
+                    total_batches = (len(memory_ids_list) + batch_size - 1) // batch_size
+                    logger.info(
+                        f"[delete_node_by_prams] memoryids Processing {len(memory_ids_list)} memory_ids in {total_batches} batches (batch_size={batch_size})"
+                    )
 
-        # Build final WHERE clause
-        # If user_name_conditions exist, combine with data_conditions using AND
-        # Otherwise, use only data_conditions
-        if user_name_conditions:
-            user_name_where = " OR ".join(user_name_conditions)
-            where_clause = f"({user_name_where}) AND ({data_conditions})"
-        else:
-            where_clause = f"({data_conditions})"
+                    for batch_idx in range(total_batches):
+                        batch_start = batch_idx * batch_size
+                        batch_end = min(batch_start + batch_size, len(memory_ids_list))
+                        batch_ids = memory_ids_list[batch_start:batch_end]
 
-        # Use SQL DELETE query for better performance
-        # First count matching nodes to get accurate count
-        count_query = f"""
-            SELECT COUNT(*)
-            FROM "{self.db_name}_graph"."Memory"
-            WHERE {where_clause}
-        """
-        logger.info(f"[delete_node_by_prams] count_query: {count_query}")
+                        # Build conditions for this batch
+                        batch_conditions = []
+                        for node_id in batch_ids:
+                            batch_conditions.append(
+                                f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
+                            )
+                        batch_where = f"({' OR '.join(batch_conditions)})"
 
-        # Then delete nodes
-        delete_query = f"""
-            DELETE FROM "{self.db_name}_graph"."Memory"
-            WHERE {where_clause}
-        """
+                        # Add user_name filter if provided
+                        if user_name_conditions:
+                            user_name_where = " OR ".join(user_name_conditions)
+                            where_clause = f"({user_name_where}) AND ({batch_where})"
+                        else:
+                            where_clause = batch_where
 
-        logger.info(
-            f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
-        )
-        logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
+                        # Count before deletion
+                        count_query = f"""
+                            SELECT COUNT(*)
+                            FROM "{self.db_name}_graph"."Memory"
+                            WHERE {where_clause}
+                        """
+                        logger.info(
+                            f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: count_query: {count_query}"
+                        )
 
-        conn = None
-        deleted_count = 0
-        try:
-            conn = self._get_connection()
-            with conn.cursor() as cursor:
-                # Count nodes before deletion
-                cursor.execute(count_query)
-                count_result = cursor.fetchone()
-                expected_count = count_result[0] if count_result else 0
+                        cursor.execute(count_query)
+                        count_result = cursor.fetchone()
+                        expected_count = count_result[0] if count_result else 0
 
-                logger.info(
-                    f"[delete_node_by_prams] Found {expected_count} nodes matching the criteria"
-                )
+                        if expected_count == 0:
+                            logger.info(
+                                f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: No nodes found, skipping"
+                            )
+                            continue
+
+                        # Delete batch
+                        delete_query = f"""
+                            DELETE FROM "{self.db_name}_graph"."Memory"
+                            WHERE {where_clause}
+                        """
+                        logger.info(
+                            f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: delete_query: {delete_query}"
+                        )
+
+                        logger.info(
+                            f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Executing delete query for {len(batch_ids)} nodes"
+                        )
+                        cursor.execute(delete_query)
+                        batch_deleted = cursor.rowcount
+                        total_deleted_count += batch_deleted
+
+                        logger.info(
+                            f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Deleted {batch_deleted} nodes (batch size: {len(batch_ids)})"
+                        )
+
+                # Process file_ids in batches
+                if file_ids:
+                    total_file_batches = (len(file_ids) + batch_size - 1) // batch_size
+                    logger.info(
+                        f"[delete_node_by_prams] Processing {len(file_ids)} file_ids in {total_file_batches} batches (batch_size={batch_size})"
+                    )
+
+                    for batch_idx in range(total_file_batches):
+                        batch_start = batch_idx * batch_size
+                        batch_end = min(batch_start + batch_size, len(file_ids))
+                        batch_file_ids = file_ids[batch_start:batch_end]
+
+                        # Build conditions for this batch
+                        batch_conditions = []
+                        for file_id in batch_file_ids:
+                            batch_conditions.append(
+                                f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
+                            )
+                        batch_where = f"({' OR '.join(batch_conditions)})"
+
+                        # Add user_name filter if provided
+                        if user_name_conditions:
+                            user_name_where = " OR ".join(user_name_conditions)
+                            where_clause = f"({user_name_where}) AND ({batch_where})"
+                        else:
+                            where_clause = batch_where
+
+                        # Count before deletion
+                        count_query = f"""
+                            SELECT COUNT(*)
+                            FROM "{self.db_name}_graph"."Memory"
+                            WHERE {where_clause}
+                        """
+
+                        logger.info(
+                            f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: count_query: {count_query}"
+                        )
+                        cursor.execute(count_query)
+                        count_result = cursor.fetchone()
+                        expected_count = count_result[0] if count_result else 0
+
+                        if expected_count == 0:
+                            logger.info(
+                                f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: No nodes found, skipping"
+                            )
+                            continue
+
+                        # Delete batch
+                        delete_query = f"""
+                            DELETE FROM "{self.db_name}_graph"."Memory"
+                            WHERE {where_clause}
+                        """
+                        cursor.execute(delete_query)
+                        batch_deleted = cursor.rowcount
+                        total_deleted_count += batch_deleted
+
+                        logger.info(
+                            f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: delete_query: {delete_query}"
+                        )
 
-                # Delete nodes
-                cursor.execute(delete_query)
-                # Use rowcount to get actual deleted count
-                deleted_count = cursor.rowcount
             elapsed_time = time.time() - batch_start_time
             logger.info(
-                f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, deleted {deleted_count} nodes"
+                f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes"
             )
         except Exception as e:
             logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
             raise
         finally:
             self._return_connection(conn)
 
-        logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
-        return deleted_count
+        logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes")
+        return total_deleted_count
 
     @timed
     def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
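Both loops in the rewrite share the same ceil-division chunking. A standalone sketch of that loop shape, detached from the AGE/SQL specifics (iter_batches is a hypothetical helper, not part of this module):

    def iter_batches(items: list[str], batch_size: int = 100):
        """Yield (batch_idx, total_batches, chunk), slicing the way delete_node_by_prams does."""
        total_batches = (len(items) + batch_size - 1) // batch_size  # ceil division
        for batch_idx in range(total_batches):
            start = batch_idx * batch_size
            yield batch_idx, total_batches, items[start : start + batch_size]

    # 250 ids with batch_size=100 produce chunks of 100, 100, and 50.
    for idx, total, chunk in iter_batches([f"id-{i}" for i in range(250)]):
        print(f"batch {idx + 1}/{total}: {len(chunk)} ids")

Batching bounds the length of each generated OR chain, so a request with thousands of ids issues several moderate DELETE statements instead of interpolating every id into one enormous WHERE clause, and the per-batch COUNT lets empty batches be skipped before any DELETE runs.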
