Skip to content

Commit 8b5f796

Browse files
authored
fix file_ids (#615)
* add delete_node_by_prams for neo4j_community.py * fix
1 parent 3f99afd commit 8b5f796

File tree

3 files changed

+133
-2
lines changed

3 files changed

+133
-2
lines changed

src/memos/graph_dbs/neo4j.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1588,7 +1588,7 @@ def delete_node_by_prams(
15881588
file_id_and_conditions.append(f"${param_name} IN n.file_ids")
15891589
if file_id_and_conditions:
15901590
# Use OR so that a node containing any of the given file_ids matches
1591-
where_clauses.append(f"({' AND '.join(file_id_and_conditions)})")
1591+
where_clauses.append(f"({' OR '.join(file_id_and_conditions)})")
15921592

15931593
# Query nodes by filter if provided
15941594
filter_ids = []

src/memos/graph_dbs/neo4j_community.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,136 @@ def build_filter_condition(
706706
result = session.run(query, params)
707707
return [record["id"] for record in result]
708708

709+
def delete_node_by_prams(
710+
self,
711+
writable_cube_ids: list[str],
712+
memory_ids: list[str] | None = None,
713+
file_ids: list[str] | None = None,
714+
filter: dict | None = None,
715+
) -> int:
716+
"""
717+
Delete nodes by memory_ids, file_ids, or filter.
718+
719+
Args:
720+
writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
721+
memory_ids (list[str], optional): List of memory node IDs to delete.
722+
file_ids (list[str], optional): List of file node IDs to delete.
723+
filter (dict, optional): Filter dictionary to query matching nodes for deletion.
724+
725+
Returns:
726+
int: Number of nodes deleted.
727+
"""
728+
logger.info(
729+
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
730+
)
731+
print(
732+
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
733+
)
734+
735+
# Validate writable_cube_ids
736+
if not writable_cube_ids or len(writable_cube_ids) == 0:
737+
raise ValueError("writable_cube_ids is required and cannot be empty")
738+
739+
# Build WHERE conditions separately for memory_ids and file_ids
740+
where_clauses = []
741+
params = {}
742+
743+
# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
744+
user_name_conditions = []
745+
for idx, cube_id in enumerate(writable_cube_ids):
746+
param_name = f"cube_id_{idx}"
747+
user_name_conditions.append(f"n.user_name = ${param_name}")
748+
params[param_name] = cube_id
749+
750+
# Handle memory_ids: query n.id
751+
if memory_ids and len(memory_ids) > 0:
752+
where_clauses.append("n.id IN $memory_ids")
753+
params["memory_ids"] = memory_ids
754+
755+
# Handle file_ids: query n.file_ids field
756+
# All file_ids must be present in the array field (AND relationship)
757+
if file_ids and len(file_ids) > 0:
758+
file_id_and_conditions = []
759+
for idx, file_id in enumerate(file_ids):
760+
param_name = f"file_id_{idx}"
761+
params[param_name] = file_id
762+
# Check if this file_id is in the file_ids array field
763+
file_id_and_conditions.append(f"${param_name} IN n.file_ids")
764+
if file_id_and_conditions:
765+
# Use AND to require all file_ids to be present
766+
where_clauses.append(f"({' AND '.join(file_id_and_conditions)})")
767+
768+
# Query nodes by filter if provided
769+
filter_ids = []
770+
if filter:
771+
# Use get_by_metadata with empty filters list and filter
772+
filter_ids = self.get_by_metadata(
773+
filters=[],
774+
user_name=None,
775+
filter=filter,
776+
knowledgebase_ids=writable_cube_ids,
777+
)
778+
779+
# If filter returned IDs, add condition for them
780+
if filter_ids:
781+
where_clauses.append("n.id IN $filter_ids")
782+
params["filter_ids"] = filter_ids
783+
784+
# If no conditions (except user_name), return 0
785+
if not where_clauses:
786+
logger.warning(
787+
"[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
788+
)
789+
return 0
790+
791+
# Build WHERE clause
792+
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
793+
data_conditions = " OR ".join([f"({clause})" for clause in where_clauses])
794+
795+
# Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
796+
user_name_where = " OR ".join(user_name_conditions)
797+
ids_where = f"({user_name_where}) AND ({data_conditions})"
798+
799+
logger.info(
800+
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
801+
)
802+
print(
803+
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
804+
)
805+
806+
# First count matching nodes to get accurate count
807+
count_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN count(n) AS node_count"
808+
logger.info(f"[delete_node_by_prams] count_query: {count_query}")
809+
print(f"[delete_node_by_prams] count_query: {count_query}")
810+
811+
# Then delete nodes
812+
delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n"
813+
logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
814+
print(f"[delete_node_by_prams] delete_query: {delete_query}")
815+
print(f"[delete_node_by_prams] params: {params}")
816+
817+
deleted_count = 0
818+
try:
819+
with self.driver.session(database=self.db_name) as session:
820+
# Count nodes before deletion
821+
count_result = session.run(count_query, **params)
822+
count_record = count_result.single()
823+
expected_count = 0
824+
if count_record:
825+
expected_count = count_record["node_count"] or 0
826+
827+
# Delete nodes
828+
session.run(delete_query, **params)
829+
# Use the count from before deletion as the actual deleted count
830+
deleted_count = expected_count
831+
832+
except Exception as e:
833+
logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
834+
raise
835+
836+
logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
837+
return deleted_count
838+
709839
def clear(self, user_name: str | None = None) -> None:
710840
"""
711841
Clear the entire graph if the target database exists.

src/memos/graph_dbs/polardb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4113,6 +4113,7 @@ def parse_filter(
41134113
"memory_type",
41144114
"node_type",
41154115
"info",
4116+
"source",
41164117
}
41174118

41184119
def process_condition(condition):
@@ -4216,7 +4217,7 @@ def delete_node_by_prams(
42164217
file_id_and_conditions.append(f"'{escaped_id}' IN n.file_ids")
42174218
if file_id_and_conditions:
42184219
# Use OR so that a node containing any of the given file_ids matches
4219-
where_conditions.append(f"({' AND '.join(file_id_and_conditions)})")
4220+
where_conditions.append(f"({' OR '.join(file_id_and_conditions)})")
42204221

42214222
# Query nodes by filter if provided
42224223
filter_ids = set()

0 commit comments

Comments
 (0)