Skip to content

Commit c30feee

Browse files
authored
Dev zdy 1221 01 user names (#759)
* add get_user_names_by_memory_ids * update delete_node_by_prams by no user_name * update delete_node_by_prams by no user_name
1 parent 15b475b commit c30feee

File tree

1 file changed

+106
-14
lines changed

1 file changed

+106
-14
lines changed

src/memos/graph_dbs/polardb.py

Lines changed: 106 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4763,7 +4763,7 @@ def process_condition(condition):
47634763
@timed
47644764
def delete_node_by_prams(
47654765
self,
4766-
writable_cube_ids: list[str],
4766+
writable_cube_ids: list[str] | None = None,
47674767
memory_ids: list[str] | None = None,
47684768
file_ids: list[str] | None = None,
47694769
filter: dict | None = None,
@@ -4772,7 +4772,8 @@ def delete_node_by_prams(
47724772
Delete nodes by memory_ids, file_ids, or filter.
47734773
47744774
Args:
4775-
writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
4775+
writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes.
4776+
If not provided, no user_name filter will be applied.
47764777
memory_ids (list[str], optional): List of memory node IDs to delete.
47774778
file_ids (list[str], optional): List of file node IDs to delete.
47784779
filter (dict, optional): Filter dictionary to query matching nodes for deletion.
@@ -4785,17 +4786,15 @@ def delete_node_by_prams(
47854786
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
47864787
)
47874788

4788-
# Validate writable_cube_ids
4789-
if not writable_cube_ids or len(writable_cube_ids) == 0:
4790-
raise ValueError("writable_cube_ids is required and cannot be empty")
4791-
47924789
# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
4790+
# Only add user_name filter if writable_cube_ids is provided
47934791
user_name_conditions = []
4794-
for cube_id in writable_cube_ids:
4795-
# Use agtype_access_operator with VARIADIC ARRAY format for consistency
4796-
user_name_conditions.append(
4797-
f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
4798-
)
4792+
if writable_cube_ids and len(writable_cube_ids) > 0:
4793+
for cube_id in writable_cube_ids:
4794+
# Use agtype_access_operator with VARIADIC ARRAY format for consistency
4795+
user_name_conditions.append(
4796+
f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
4797+
)
47994798

48004799
# Build WHERE conditions separately for memory_ids and file_ids
48014800
where_conditions = []
@@ -4863,9 +4862,14 @@ def delete_node_by_prams(
48634862
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
48644863
data_conditions = " OR ".join([f"({cond})" for cond in where_conditions])
48654864

4866-
# Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
4867-
user_name_where = " OR ".join(user_name_conditions)
4868-
where_clause = f"({user_name_where}) AND ({data_conditions})"
4865+
# Build final WHERE clause
4866+
# If user_name_conditions exist, combine with data_conditions using AND
4867+
# Otherwise, use only data_conditions
4868+
if user_name_conditions:
4869+
user_name_where = " OR ".join(user_name_conditions)
4870+
where_clause = f"({user_name_where}) AND ({data_conditions})"
4871+
else:
4872+
where_clause = f"({data_conditions})"
48694873

48704874
# Use SQL DELETE query for better performance
48714875
# First count matching nodes to get accurate count
@@ -4917,3 +4921,91 @@ def delete_node_by_prams(
49174921

49184922
logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
49194923
return deleted_count
4924+
4925+
@timed
4926+
def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
4927+
"""Get user names by memory ids.
4928+
4929+
Args:
4930+
memory_ids: List of memory node IDs to query.
4931+
4932+
Returns:
4933+
dict[str, list[str]]: Dictionary with one key:
4934+
- 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing)
4935+
- 'exist_user_names': List of distinct user names (if all memory_ids exist)
4936+
"""
4937+
if not memory_ids:
4938+
return {"exist_user_names": []}
4939+
4940+
# Build OR conditions for each memory_id
4941+
id_conditions = []
4942+
for mid in memory_ids:
4943+
id_conditions.append(
4944+
f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{mid}\"'::agtype"
4945+
)
4946+
4947+
where_clause = f"({' OR '.join(id_conditions)})"
4948+
4949+
# Query to check which memory_ids exist
4950+
check_query = f"""
4951+
SELECT ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text
4952+
FROM "{self.db_name}_graph"."Memory"
4953+
WHERE {where_clause}
4954+
"""
4955+
4956+
logger.info(f"[get_user_names_by_memory_ids] check_query: {check_query}")
4957+
conn = None
4958+
try:
4959+
conn = self._get_connection()
4960+
with conn.cursor() as cursor:
4961+
# Check which memory_ids exist
4962+
cursor.execute(check_query)
4963+
check_results = cursor.fetchall()
4964+
existing_ids = set()
4965+
for row in check_results:
4966+
node_id = row[0]
4967+
# Remove quotes if present
4968+
if isinstance(node_id, str):
4969+
node_id = node_id.strip('"').strip("'")
4970+
existing_ids.add(node_id)
4971+
4972+
# Check if any memory_ids are missing
4973+
no_exist_list = [mid for mid in memory_ids if mid not in existing_ids]
4974+
4975+
# If any memory_ids are missing, return no_exist_memory_ids
4976+
if no_exist_list:
4977+
logger.info(
4978+
f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}"
4979+
)
4980+
return {"no_exist_memory_ids": no_exist_list}
4981+
4982+
# All memory_ids exist, query user_names
4983+
user_names_query = f"""
4984+
SELECT DISTINCT ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text
4985+
FROM "{self.db_name}_graph"."Memory"
4986+
WHERE {where_clause}
4987+
"""
4988+
logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}")
4989+
4990+
cursor.execute(user_names_query)
4991+
results = cursor.fetchall()
4992+
user_names = []
4993+
for row in results:
4994+
user_name = row[0]
4995+
# Remove quotes if present
4996+
if isinstance(user_name, str):
4997+
user_name = user_name.strip('"').strip("'")
4998+
user_names.append(user_name)
4999+
5000+
logger.info(
5001+
f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names"
5002+
)
5003+
5004+
return {"exist_user_names": user_names}
5005+
except Exception as e:
5006+
logger.error(
5007+
f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
5008+
)
5009+
raise
5010+
finally:
5011+
self._return_connection(conn)

0 commit comments

Comments
 (0)