Skip to content

Commit fac1aa7

Browse files
authored
feat: add batch delete (#787)
1 parent 10342ef commit fac1aa7

File tree

1 file changed

+137
-82
lines changed

1 file changed

+137
-82
lines changed

src/memos/graph_dbs/polardb.py

Lines changed: 137 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -4869,6 +4869,7 @@ def delete_node_by_prams(
48694869
memory_ids: list[str] | None = None,
48704870
file_ids: list[str] | None = None,
48714871
filter: dict | None = None,
4872+
batch_size: int = 100,
48724873
) -> int:
48734874
"""
48744875
Delete nodes by memory_ids, file_ids, or filter.
@@ -4898,31 +4899,6 @@ def delete_node_by_prams(
48984899
f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
48994900
)
49004901

4901-
# Build WHERE conditions separately for memory_ids and file_ids
4902-
where_conditions = []
4903-
4904-
# Handle memory_ids: query properties.id
4905-
if memory_ids and len(memory_ids) > 0:
4906-
memory_id_conditions = []
4907-
for node_id in memory_ids:
4908-
memory_id_conditions.append(
4909-
f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
4910-
)
4911-
if memory_id_conditions:
4912-
where_conditions.append(f"({' OR '.join(memory_id_conditions)})")
4913-
4914-
# Check if any file_id is in the file_ids array field (OR relationship)
4915-
if file_ids and len(file_ids) > 0:
4916-
file_id_conditions = []
4917-
for file_id in file_ids:
4918-
# Format: agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '"file_ids"'::agtype]), '"file_id"'::agtype)
4919-
file_id_conditions.append(
4920-
f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
4921-
)
4922-
if file_id_conditions:
4923-
# Use OR to match any file_id in the array
4924-
where_conditions.append(f"({' OR '.join(file_id_conditions)})")
4925-
49264902
# Query nodes by filter if provided
49274903
filter_ids = set()
49284904
if filter:
@@ -4943,86 +4919,165 @@ def delete_node_by_prams(
49434919
"[delete_node_by_prams] Filter parsed to None, skipping filter query"
49444920
)
49454921

4946-
# If filter returned IDs, add condition for them
4922+
# Combine all IDs that need to be deleted
4923+
all_memory_ids = set()
4924+
if memory_ids:
4925+
all_memory_ids.update(memory_ids)
49474926
if filter_ids:
4948-
filter_id_conditions = []
4949-
for node_id in filter_ids:
4950-
filter_id_conditions.append(
4951-
f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
4952-
)
4953-
if filter_id_conditions:
4954-
where_conditions.append(f"({' OR '.join(filter_id_conditions)})")
4927+
all_memory_ids.update(filter_ids)
49554928

4956-
# If no conditions (except user_name), return 0
4957-
if not where_conditions:
4929+
# If no conditions to delete, return 0
4930+
if not all_memory_ids and not file_ids:
49584931
logger.warning(
49594932
"[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
49604933
)
49614934
return 0
49624935

4963-
# Build WHERE clause
4964-
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
4965-
data_conditions = " OR ".join([f"({cond})" for cond in where_conditions])
4936+
conn = None
4937+
total_deleted_count = 0
4938+
try:
4939+
conn = self._get_connection()
4940+
with conn.cursor() as cursor:
4941+
# Process memory_ids and filter_ids in batches
4942+
if all_memory_ids:
4943+
memory_ids_list = list(all_memory_ids)
4944+
total_batches = (len(memory_ids_list) + batch_size - 1) // batch_size
4945+
logger.info(
4946+
f"[delete_node_by_prams] memoryids Processing {len(memory_ids_list)} memory_ids in {total_batches} batches (batch_size={batch_size})"
4947+
)
49664948

4967-
# Build final WHERE clause
4968-
# If user_name_conditions exist, combine with data_conditions using AND
4969-
# Otherwise, use only data_conditions
4970-
if user_name_conditions:
4971-
user_name_where = " OR ".join(user_name_conditions)
4972-
where_clause = f"({user_name_where}) AND ({data_conditions})"
4973-
else:
4974-
where_clause = f"({data_conditions})"
4949+
for batch_idx in range(total_batches):
4950+
batch_start = batch_idx * batch_size
4951+
batch_end = min(batch_start + batch_size, len(memory_ids_list))
4952+
batch_ids = memory_ids_list[batch_start:batch_end]
49754953

4976-
# Use SQL DELETE query for better performance
4977-
# First count matching nodes to get accurate count
4978-
count_query = f"""
4979-
SELECT COUNT(*)
4980-
FROM "{self.db_name}_graph"."Memory"
4981-
WHERE {where_clause}
4982-
"""
4983-
logger.info(f"[delete_node_by_prams] count_query: {count_query}")
4954+
# Build conditions for this batch
4955+
batch_conditions = []
4956+
for node_id in batch_ids:
4957+
batch_conditions.append(
4958+
f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
4959+
)
4960+
batch_where = f"({' OR '.join(batch_conditions)})"
49844961

4985-
# Then delete nodes
4986-
delete_query = f"""
4987-
DELETE FROM "{self.db_name}_graph"."Memory"
4988-
WHERE {where_clause}
4989-
"""
4962+
# Add user_name filter if provided
4963+
if user_name_conditions:
4964+
user_name_where = " OR ".join(user_name_conditions)
4965+
where_clause = f"({user_name_where}) AND ({batch_where})"
4966+
else:
4967+
where_clause = batch_where
49904968

4991-
logger.info(
4992-
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
4993-
)
4994-
logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
4969+
# Count before deletion
4970+
count_query = f"""
4971+
SELECT COUNT(*)
4972+
FROM "{self.db_name}_graph"."Memory"
4973+
WHERE {where_clause}
4974+
"""
4975+
logger.info(
4976+
f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: count_query: {count_query}"
4977+
)
49954978

4996-
conn = None
4997-
deleted_count = 0
4998-
try:
4999-
conn = self._get_connection()
5000-
with conn.cursor() as cursor:
5001-
# Count nodes before deletion
5002-
cursor.execute(count_query)
5003-
count_result = cursor.fetchone()
5004-
expected_count = count_result[0] if count_result else 0
4979+
cursor.execute(count_query)
4980+
count_result = cursor.fetchone()
4981+
expected_count = count_result[0] if count_result else 0
50054982

5006-
logger.info(
5007-
f"[delete_node_by_prams] Found {expected_count} nodes matching the criteria"
5008-
)
4983+
if expected_count == 0:
4984+
logger.info(
4985+
f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: No nodes found, skipping"
4986+
)
4987+
continue
4988+
4989+
# Delete batch
4990+
delete_query = f"""
4991+
DELETE FROM "{self.db_name}_graph"."Memory"
4992+
WHERE {where_clause}
4993+
"""
4994+
logger.info(
4995+
f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: delete_query: {delete_query}"
4996+
)
4997+
4998+
logger.info(
4999+
f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Executing delete query for {len(batch_ids)} nodes"
5000+
)
5001+
cursor.execute(delete_query)
5002+
batch_deleted = cursor.rowcount
5003+
total_deleted_count += batch_deleted
5004+
5005+
logger.info(
5006+
f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Deleted {batch_deleted} nodes (batch size: {len(batch_ids)})"
5007+
)
5008+
5009+
# Process file_ids in batches
5010+
if file_ids:
5011+
total_file_batches = (len(file_ids) + batch_size - 1) // batch_size
5012+
logger.info(
5013+
f"[delete_node_by_prams] Processing {len(file_ids)} file_ids in {total_file_batches} batches (batch_size={batch_size})"
5014+
)
5015+
5016+
for batch_idx in range(total_file_batches):
5017+
batch_start = batch_idx * batch_size
5018+
batch_end = min(batch_start + batch_size, len(file_ids))
5019+
batch_file_ids = file_ids[batch_start:batch_end]
5020+
5021+
# Build conditions for this batch
5022+
batch_conditions = []
5023+
for file_id in batch_file_ids:
5024+
batch_conditions.append(
5025+
f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
5026+
)
5027+
batch_where = f"({' OR '.join(batch_conditions)})"
5028+
5029+
# Add user_name filter if provided
5030+
if user_name_conditions:
5031+
user_name_where = " OR ".join(user_name_conditions)
5032+
where_clause = f"({user_name_where}) AND ({batch_where})"
5033+
else:
5034+
where_clause = batch_where
5035+
5036+
# Count before deletion
5037+
count_query = f"""
5038+
SELECT COUNT(*)
5039+
FROM "{self.db_name}_graph"."Memory"
5040+
WHERE {where_clause}
5041+
"""
5042+
5043+
logger.info(
5044+
f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: count_query: {count_query}"
5045+
)
5046+
cursor.execute(count_query)
5047+
count_result = cursor.fetchone()
5048+
expected_count = count_result[0] if count_result else 0
5049+
5050+
if expected_count == 0:
5051+
logger.info(
5052+
f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: No nodes found, skipping"
5053+
)
5054+
continue
5055+
5056+
# Delete batch
5057+
delete_query = f"""
5058+
DELETE FROM "{self.db_name}_graph"."Memory"
5059+
WHERE {where_clause}
5060+
"""
5061+
cursor.execute(delete_query)
5062+
batch_deleted = cursor.rowcount
5063+
total_deleted_count += batch_deleted
5064+
5065+
logger.info(
5066+
f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: delete_query: {delete_query}"
5067+
)
50095068

5010-
# Delete nodes
5011-
cursor.execute(delete_query)
5012-
# Use rowcount to get actual deleted count
5013-
deleted_count = cursor.rowcount
50145069
elapsed_time = time.time() - batch_start_time
50155070
logger.info(
5016-
f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, deleted {deleted_count} nodes"
5071+
f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes"
50175072
)
50185073
except Exception as e:
50195074
logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
50205075
raise
50215076
finally:
50225077
self._return_connection(conn)
50235078

5024-
logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
5025-
return deleted_count
5079+
logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes")
5080+
return total_deleted_count
50265081

50275082
@timed
50285083
def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:

0 commit comments

Comments
 (0)