Commit 275ddc8

wustzdy, fridayL, and CaralHsi
authored
add nodes batch (#658)
* add_nodes_batch for polardb.py
* add_nodes_batch for neo4j.py

Co-authored-by: chunyu li <[email protected]>
Co-authored-by: CaralHsi <[email protected]>
1 parent 3d5a6e5 commit 275ddc8

File tree

3 files changed: +301 −0 lines changed


src/memos/graph_dbs/base.py

Lines changed: 13 additions & 0 deletions
@@ -250,3 +250,16 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> l
        Returns:
            list[dict]: Full list of memory items under this scope.
        """

    @abstractmethod
    def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = None) -> None:
        """
        Batch add multiple memory nodes to the graph.

        Args:
            nodes: List of node dictionaries, each containing:
                - id: str - Node ID
                - memory: str - Memory content
                - metadata: dict[str, Any] - Node metadata
            user_name: Optional user name (will use config default if not provided)
        """

src/memos/graph_dbs/neo4j.py

Lines changed: 104 additions & 0 deletions
@@ -236,6 +236,110 @@ def add_node(
            metadata=metadata,
        )

    def add_nodes_batch(
        self,
        nodes: list[dict[str, Any]],
        user_name: str | None = None,
    ) -> None:
        """
        Batch add multiple memory nodes to the graph.

        Args:
            nodes: List of node dictionaries, each containing:
                - id: str - Node ID
                - memory: str - Memory content
                - metadata: dict[str, Any] - Node metadata
            user_name: Optional user name (will use config default if not provided)
        """
        if not nodes:
            logger.warning("[add_nodes_batch] Empty nodes list, skipping")
            return

        logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes")

        # user_name comes from parameter; fallback to config if missing
        effective_user_name = user_name if user_name else self.config.user_name

        # Prepare all nodes
        prepared_nodes = []
        for node_data in nodes:
            try:
                id = node_data["id"]
                memory = node_data["memory"]
                metadata = node_data.get("metadata", {})

                logger.debug(f"[add_nodes_batch] Processing node id: {id}")

                # Set user_name in metadata if needed
                if not self.config.use_multi_db and (self.config.user_name or effective_user_name):
                    metadata["user_name"] = effective_user_name

                # Safely process metadata
                metadata = _prepare_node_metadata(metadata)

                # Flatten info fields to top level (for Neo4j flat structure)
                metadata = _flatten_info_fields(metadata)

                # Merge node and set metadata
                created_at = metadata.pop("created_at")
                updated_at = metadata.pop("updated_at")

                # Serialization for sources
                if metadata.get("sources"):
                    for idx in range(len(metadata["sources"])):
                        metadata["sources"][idx] = json.dumps(metadata["sources"][idx])

                prepared_nodes.append(
                    {
                        "id": id,
                        "memory": memory,
                        "created_at": created_at,
                        "updated_at": updated_at,
                        "metadata": metadata,
                    }
                )
            except Exception as e:
                logger.error(
                    f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}",
                    exc_info=True,
                )
                # Continue with other nodes
                continue

        if not prepared_nodes:
            logger.warning("[add_nodes_batch] No valid nodes to insert after preparation")
            return

        # Batch insert using Neo4j UNWIND for better performance
        query = """
        UNWIND $nodes AS node
        MERGE (n:Memory {id: node.id})
        SET n.memory = node.memory,
            n.created_at = datetime(node.created_at),
            n.updated_at = datetime(node.updated_at),
            n += node.metadata
        """

        # Prepare nodes data for UNWIND
        nodes_data = [
            {
                "id": node["id"],
                "memory": node["memory"],
                "created_at": node["created_at"],
                "updated_at": node["updated_at"],
                "metadata": node["metadata"],
            }
            for node in prepared_nodes
        ]

        try:
            with self.driver.session(database=self.db_name) as session:
                session.run(query, nodes=nodes_data)
                logger.info(f"[add_nodes_batch] Successfully inserted {len(prepared_nodes)} nodes")
        except Exception as e:
            logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True)
            raise

    def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
        """
        Update node fields in Neo4j, auto-converting `created_at` and `updated_at` to datetime type if present.
src/memos/graph_dbs/polardb.py

Lines changed: 184 additions & 0 deletions
@@ -3226,6 +3226,190 @@ def add_node(
            logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}")
            self._return_connection(conn)

    @timed
    def add_nodes_batch(
        self,
        nodes: list[dict[str, Any]],
        user_name: str | None = None,
    ) -> None:
        """
        Batch add multiple memory nodes to the graph.

        Args:
            nodes: List of node dictionaries, each containing:
                - id: str - Node ID
                - memory: str - Memory content
                - metadata: dict[str, Any] - Node metadata
            user_name: Optional user name (will use config default if not provided)
        """
        if not nodes:
            logger.warning("[add_nodes_batch] Empty nodes list, skipping")
            return

        logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes")

        # user_name comes from parameter; fallback to config if missing
        effective_user_name = user_name if user_name else self.config.user_name

        # Prepare all nodes
        prepared_nodes = []
        for node_data in nodes:
            try:
                id = node_data["id"]
                memory = node_data["memory"]
                metadata = node_data.get("metadata", {})

                logger.debug(f"[add_nodes_batch] Processing node id: {id}")

                # Set user_name in metadata
                metadata["user_name"] = effective_user_name

                metadata = _prepare_node_metadata(metadata)

                # Merge node and set metadata
                created_at = metadata.pop("created_at", datetime.utcnow().isoformat())
                updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat())

                # Prepare properties
                properties = {
                    "id": id,
                    "memory": memory,
                    "created_at": created_at,
                    "updated_at": updated_at,
                    **metadata,
                }

                # Generate embedding if not provided
                if "embedding" not in properties or not properties["embedding"]:
                    properties["embedding"] = generate_vector(
                        self._get_config_value("embedding_dimension", 1024)
                    )

                # Serialization - JSON-serialize sources and usage fields
                for field_name in ["sources", "usage"]:
                    if properties.get(field_name):
                        if isinstance(properties[field_name], list):
                            for idx in range(len(properties[field_name])):
                                # Serialize only when element is not a string
                                if not isinstance(properties[field_name][idx], str):
                                    properties[field_name][idx] = json.dumps(
                                        properties[field_name][idx]
                                    )
                        elif isinstance(properties[field_name], str):
                            # If already a string, leave as-is
                            pass

                # Extract embedding for separate column
                embedding_vector = properties.pop("embedding", [])
                if not isinstance(embedding_vector, list):
                    embedding_vector = []

                # Select column name based on embedding dimension
                embedding_column = "embedding"  # default column
                if len(embedding_vector) == 3072:
                    embedding_column = "embedding_3072"
                elif len(embedding_vector) == 1024:
                    embedding_column = "embedding"
                elif len(embedding_vector) == 768:
                    embedding_column = "embedding_768"

                prepared_nodes.append(
                    {
                        "id": id,
                        "memory": memory,
                        "properties": properties,
                        "embedding_vector": embedding_vector,
                        "embedding_column": embedding_column,
                    }
                )
            except Exception as e:
                logger.error(
                    f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}",
                    exc_info=True,
                )
                # Continue with other nodes
                continue

        if not prepared_nodes:
            logger.warning("[add_nodes_batch] No valid nodes to insert after preparation")
            return

        # Group nodes by embedding column to optimize batch inserts
        nodes_by_embedding_column = {}
        for node in prepared_nodes:
            col = node["embedding_column"]
            if col not in nodes_by_embedding_column:
                nodes_by_embedding_column[col] = []
            nodes_by_embedding_column[col].append(node)

        conn = None
        try:
            conn = self._get_connection()
            with conn.cursor() as cursor:
                # Process each group separately
                for embedding_column, nodes_group in nodes_by_embedding_column.items():
                    # Delete existing records first (batch delete)
                    for node in nodes_group:
                        delete_query = f"""
                            DELETE FROM {self.db_name}_graph."Memory"
                            WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring)
                        """
                        cursor.execute(delete_query, (node["id"],))

                    # Insert nodes (one execute per node)
                    for node in nodes_group:
                        # Get graph_id for this node
                        get_graph_id_query = f"""
                            SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring)
                        """
                        cursor.execute(get_graph_id_query, (node["id"],))
                        graph_id = cursor.fetchone()[0]
                        node["properties"]["graph_id"] = str(graph_id)

                        # Insert node
                        if node["embedding_vector"]:
                            insert_query = f"""
                                INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column})
                                VALUES (
                                    ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
                                    %s,
                                    %s
                                )
                            """
                            logger.info(
                                f"[add_nodes_batch] Inserting node insert_query={insert_query}"
                            )
                            cursor.execute(
                                insert_query,
                                (
                                    node["id"],
                                    json.dumps(node["properties"]),
                                    json.dumps(node["embedding_vector"]),
                                ),
                            )
                        else:
                            insert_query = f"""
                                INSERT INTO {self.db_name}_graph."Memory"(id, properties)
                                VALUES (
                                    ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
                                    %s
                                )
                            """
                            cursor.execute(
                                insert_query,
                                (node["id"], json.dumps(node["properties"])),
                            )

                    logger.info(
                        f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}"
                    )

        except Exception as e:
            logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True)
            raise
        finally:
            self._return_connection(conn)

    def _build_node_from_agtype(self, node_agtype, embedding=None):
        """
        Parse the cypher-returned column `n` (agtype or JSON string)
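
Two PolarDB-specific details are worth calling out: each prepared row is routed to a vector column by its embedding dimension (3072 → embedding_3072, 768 → embedding_768, anything else, including the 1024 default, → embedding), and the batch is grouped by that column so every group shares one INSERT shape. A minimal sketch of just that routing rule; the helper names below are hypothetical, not part of the commit:

from collections import defaultdict

# Illustrative re-statement of the dimension-based column routing used above.
def pick_embedding_column(vector: list[float]) -> str:
    if len(vector) == 3072:
        return "embedding_3072"
    if len(vector) == 768:
        return "embedding_768"
    return "embedding"  # default column, also used for 1024-dim vectors

def group_by_column(prepared_nodes: list[dict]) -> dict[str, list[dict]]:
    groups: dict[str, list[dict]] = defaultdict(list)
    for node in prepared_nodes:
        groups[pick_embedding_column(node["embedding_vector"])].append(node)
    return dict(groups)

batch = [
    {"id": "a", "embedding_vector": [0.0] * 1024},
    {"id": "b", "embedding_vector": [0.0] * 768},
]
assert set(group_by_column(batch)) == {"embedding", "embedding_768"}

Grouping first fixes the column set of each INSERT statement per group, which is what lets the loop reuse a single query string per embedding width.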
