Skip to content

Commit 748ef3d

Browse files
authored
Dev zdy 1226 page (#796)
* feat: add export_graph page for polardb.py * feat: add export_graph page for neo4j.py * feat: add get_user_names_by_memory_ids * feat: add delete_node_by_prams
1 parent 336a2be commit 748ef3d

File tree

2 files changed

+151
-30
lines changed

2 files changed

+151
-30
lines changed

src/memos/graph_dbs/neo4j.py

Lines changed: 122 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,17 +1132,40 @@ def clear(self, user_name: str | None = None) -> None:
11321132
logger.error(f"[ERROR] Failed to clear database '{self.db_name}': {e}")
11331133
raise
11341134

1135-
def export_graph(self, **kwargs) -> dict[str, Any]:
1135+
def export_graph(
1136+
self,
1137+
page: int | None = None,
1138+
page_size: int | None = None,
1139+
**kwargs,
1140+
) -> dict[str, Any]:
11361141
"""
11371142
Export all graph nodes and edges in a structured form.
11381143
1144+
Args:
1145+
page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
1146+
page_size (int, optional): Number of items per page. If None, exports all data without pagination.
1147+
**kwargs: Additional keyword arguments, including:
1148+
- user_name (str, optional): User name for filtering in non-multi-db mode
1149+
11391150
Returns:
11401151
{
11411152
"nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
11421153
"edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
11431154
}
11441155
"""
11451156
user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
1157+
1158+
# Determine if pagination is needed
1159+
use_pagination = page is not None and page_size is not None
1160+
1161+
# Validate pagination parameters if pagination is enabled
1162+
if use_pagination:
1163+
if page < 1:
1164+
page = 1
1165+
if page_size < 1:
1166+
page_size = 10
1167+
skip = (page - 1) * page_size
1168+
11461169
with self.driver.session(database=self.db_name) as session:
11471170
# Export nodes
11481171
node_query = "MATCH (n:Memory)"
@@ -1154,13 +1177,23 @@ def export_graph(self, **kwargs) -> dict[str, Any]:
11541177
edge_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name"
11551178
params["user_name"] = user_name
11561179

1157-
node_result = session.run(f"{node_query} RETURN n", params)
1180+
# Add ORDER BY and pagination for nodes
1181+
node_query += " RETURN n ORDER BY n.id"
1182+
if use_pagination:
1183+
node_query += f" SKIP {skip} LIMIT {page_size}"
1184+
1185+
node_result = session.run(node_query, params)
11581186
nodes = [self._parse_node(dict(record["n"])) for record in node_result]
11591187

11601188
# Export edges
1161-
edge_result = session.run(
1162-
f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) AS type", params
1189+
# Add ORDER BY and pagination for edges
1190+
edge_query += (
1191+
" RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.id, b.id"
11631192
)
1193+
if use_pagination:
1194+
edge_query += f" SKIP {skip} LIMIT {page_size}"
1195+
1196+
edge_result = session.run(edge_query, params)
11641197
edges = [
11651198
{"source": record["source"], "target": record["target"], "type": record["type"]}
11661199
for record in edge_result
@@ -1646,7 +1679,7 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
16461679

16471680
def delete_node_by_prams(
16481681
self,
1649-
writable_cube_ids: list[str],
1682+
writable_cube_ids: list[str] | None = None,
16501683
memory_ids: list[str] | None = None,
16511684
file_ids: list[str] | None = None,
16521685
filter: dict | None = None,
@@ -1655,7 +1688,8 @@ def delete_node_by_prams(
16551688
Delete nodes by memory_ids, file_ids, or filter.
16561689
16571690
Args:
1658-
writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
1691+
writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes.
1692+
If not provided, no user_name filter will be applied.
16591693
memory_ids (list[str], optional): List of memory node IDs to delete.
16601694
file_ids (list[str], optional): List of file node IDs to delete.
16611695
filter (dict, optional): Filter dictionary to query matching nodes for deletion.
@@ -1670,20 +1704,18 @@ def delete_node_by_prams(
16701704
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
16711705
)
16721706

1673-
# Validate writable_cube_ids
1674-
if not writable_cube_ids or len(writable_cube_ids) == 0:
1675-
raise ValueError("writable_cube_ids is required and cannot be empty")
1676-
16771707
# Build WHERE conditions separately for memory_ids and file_ids
16781708
where_clauses = []
16791709
params = {}
16801710

16811711
# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
1712+
# Only add user_name filter if writable_cube_ids is provided
16821713
user_name_conditions = []
1683-
for idx, cube_id in enumerate(writable_cube_ids):
1684-
param_name = f"cube_id_{idx}"
1685-
user_name_conditions.append(f"n.user_name = ${param_name}")
1686-
params[param_name] = cube_id
1714+
if writable_cube_ids and len(writable_cube_ids) > 0:
1715+
for idx, cube_id in enumerate(writable_cube_ids):
1716+
param_name = f"cube_id_{idx}"
1717+
user_name_conditions.append(f"n.user_name = ${param_name}")
1718+
params[param_name] = cube_id
16871719

16881720
# Handle memory_ids: query n.id
16891721
if memory_ids and len(memory_ids) > 0:
@@ -1711,7 +1743,7 @@ def delete_node_by_prams(
17111743
filters=[],
17121744
user_name=None,
17131745
filter=filter,
1714-
knowledgebase_ids=writable_cube_ids,
1746+
knowledgebase_ids=writable_cube_ids if writable_cube_ids else None,
17151747
)
17161748

17171749
# If filter returned IDs, add condition for them
@@ -1730,9 +1762,14 @@ def delete_node_by_prams(
17301762
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
17311763
data_conditions = " OR ".join([f"({clause})" for clause in where_clauses])
17321764

1733-
# Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
1734-
user_name_where = " OR ".join(user_name_conditions)
1735-
ids_where = f"({user_name_where}) AND ({data_conditions})"
1765+
# Build final WHERE clause
1766+
# If user_name_conditions exist, combine with data_conditions using AND
1767+
# Otherwise, use only data_conditions
1768+
if user_name_conditions:
1769+
user_name_where = " OR ".join(user_name_conditions)
1770+
ids_where = f"({user_name_where}) AND ({data_conditions})"
1771+
else:
1772+
ids_where = f"({data_conditions})"
17361773

17371774
logger.info(
17381775
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
@@ -1773,3 +1810,70 @@ def delete_node_by_prams(
17731810

17741811
logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
17751812
return deleted_count
1813+
1814+
def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
1815+
"""Get user names by memory ids.
1816+
1817+
Args:
1818+
memory_ids: List of memory node IDs to query.
1819+
1820+
Returns:
1821+
dict[str, list[str]]: Dictionary with one key:
1822+
- 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing)
1823+
- 'exist_user_names': List of distinct user names (if all memory_ids exist)
1824+
"""
1825+
if not memory_ids:
1826+
return {"exist_user_names": []}
1827+
1828+
logger.info(f"[get_user_names_by_memory_ids] Checking {len(memory_ids)} memory_ids")
1829+
1830+
try:
1831+
with self.driver.session(database=self.db_name) as session:
1832+
# Query to check which memory_ids exist
1833+
check_query = """
1834+
MATCH (n:Memory)
1835+
WHERE n.id IN $memory_ids
1836+
RETURN n.id AS id
1837+
"""
1838+
1839+
check_result = session.run(check_query, memory_ids=memory_ids)
1840+
existing_ids = set()
1841+
for record in check_result:
1842+
node_id = record["id"]
1843+
existing_ids.add(node_id)
1844+
1845+
# Check if any memory_ids are missing
1846+
no_exist_list = [mid for mid in memory_ids if mid not in existing_ids]
1847+
1848+
# If any memory_ids are missing, return no_exist_memory_ids
1849+
if no_exist_list:
1850+
logger.info(
1851+
f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}"
1852+
)
1853+
return {"no_exist_memory_ids": no_exist_list}
1854+
1855+
# All memory_ids exist, query user_names
1856+
user_names_query = """
1857+
MATCH (n:Memory)
1858+
WHERE n.id IN $memory_ids
1859+
RETURN DISTINCT n.user_name AS user_name
1860+
"""
1861+
logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}")
1862+
1863+
user_names_result = session.run(user_names_query, memory_ids=memory_ids)
1864+
user_names = []
1865+
for record in user_names_result:
1866+
user_name = record["user_name"]
1867+
if user_name:
1868+
user_names.append(user_name)
1869+
1870+
logger.info(
1871+
f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names"
1872+
)
1873+
1874+
return {"exist_user_names": user_names}
1875+
except Exception as e:
1876+
logger.error(
1877+
f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
1878+
)
1879+
raise

src/memos/graph_dbs/polardb.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2505,16 +2505,16 @@ def export_graph(
25052505
self,
25062506
include_embedding: bool = False,
25072507
user_name: str | None = None,
2508-
page: int = 1,
2509-
page_size: int = 10,
2508+
page: int | None = None,
2509+
page_size: int | None = None,
25102510
) -> dict[str, Any]:
25112511
"""
25122512
Export all graph nodes and edges in a structured form.
25132513
Args:
25142514
include_embedding (bool): Whether to include the large embedding field.
25152515
user_name (str, optional): User name for filtering in non-multi-db mode
2516-
page (int): Page number (starts from 1). Default is 1.
2517-
page_size (int): Number of items per page. Default is 1000.
2516+
page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
2517+
page_size (int, optional): Number of items per page. If None, exports all data without pagination.
25182518
25192519
Returns:
25202520
{
@@ -2527,31 +2527,43 @@ def export_graph(
25272527
)
25282528
user_name = user_name if user_name else self._get_config_value("user_name")
25292529

2530-
# Validate pagination parameters
2531-
if page < 1:
2532-
page = 1
2533-
if page_size < 1:
2534-
page_size = 10
2530+
# Determine if pagination is needed
2531+
use_pagination = page is not None and page_size is not None
2532+
2533+
# Validate pagination parameters if pagination is enabled
2534+
if use_pagination:
2535+
if page < 1:
2536+
page = 1
2537+
if page_size < 1:
2538+
page_size = 10
2539+
offset = (page - 1) * page_size
2540+
else:
2541+
offset = None
25352542

25362543
conn = None
25372544
try:
25382545
conn = self._get_connection()
25392546
# Export nodes
2547+
# Build pagination clause if needed
2548+
pagination_clause = ""
2549+
if use_pagination:
2550+
pagination_clause = f"LIMIT {page_size} OFFSET {offset}"
2551+
25402552
if include_embedding:
25412553
node_query = f"""
25422554
SELECT id, properties, embedding
25432555
FROM "{self.db_name}_graph"."Memory"
25442556
WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype
25452557
ORDER BY id
2546-
LIMIT {page_size} OFFSET {(page - 1) * page_size}
2558+
{pagination_clause}
25472559
"""
25482560
else:
25492561
node_query = f"""
25502562
SELECT id, properties
25512563
FROM "{self.db_name}_graph"."Memory"
25522564
WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype
25532565
ORDER BY id
2554-
LIMIT {page_size} OFFSET {(page - 1) * page_size}
2566+
{pagination_clause}
25552567
"""
25562568
logger.info(f"[export_graph nodes] Query: {node_query}")
25572569
with conn.cursor() as cursor:
@@ -2601,6 +2613,11 @@ def export_graph(
26012613
conn = self._get_connection()
26022614
# Export edges using cypher query
26032615
# Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery
2616+
# Build pagination clause if needed
2617+
edge_pagination_clause = ""
2618+
if use_pagination:
2619+
edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}"
2620+
26042621
edge_query = f"""
26052622
SELECT source, target, edge FROM (
26062623
SELECT * FROM cypher('{self.db_name}_graph', $$
@@ -2610,7 +2627,7 @@ def export_graph(
26102627
ORDER BY a.id, b.id
26112628
$$) AS (source agtype, target agtype, edge agtype)
26122629
) AS edges
2613-
LIMIT {page_size} OFFSET {(page - 1) * page_size}
2630+
{edge_pagination_clause}
26142631
"""
26152632
logger.info(f"[export_graph edges] Query: {edge_query}")
26162633
with conn.cursor() as cursor:

0 commit comments

Comments
 (0)