
Commit a4f66b1

Author: yuan.wang (committed)
Merge branch 'dev' into feat/fix_palyground_bug
2 parents: 7e05fa7 + 35b192f

File tree: 8 files changed (+470 -3 lines)


src/memos/graph_dbs/base.py

Lines changed: 13 additions & 0 deletions
@@ -250,3 +250,16 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> l
         Returns:
             list[dict]: Full list of memory items under this scope.
         """
+
+    @abstractmethod
+    def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = None) -> None:
+        """
+        Batch add multiple memory nodes to the graph.
+
+        Args:
+            nodes: List of node dictionaries, each containing:
+                - id: str - Node ID
+                - memory: str - Memory content
+                - metadata: dict[str, Any] - Node metadata
+            user_name: Optional user name (will use config default if not provided)
+        """

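For context, the new abstract method gives every graph backend a single entry point for bulk writes. A minimal usage sketch against any concrete implementation follows; the `store` variable and the node values are illustrative placeholders, not part of this commit:

    # Hypothetical caller; `store` is any concrete backend implementing the ABC above.
    store.add_nodes_batch(
        nodes=[
            {
                "id": "node-1",
                "memory": "User prefers dark mode.",
                "metadata": {"created_at": "2024-01-01T00:00:00", "updated_at": "2024-01-01T00:00:00"},
            },
            {
                "id": "node-2",
                "memory": "User lives in Berlin.",
                "metadata": {"created_at": "2024-01-01T00:00:00", "updated_at": "2024-01-01T00:00:00"},
            },
        ],
        user_name="alice",  # optional; implementations fall back to the config default
    )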
src/memos/graph_dbs/neo4j.py

Lines changed: 104 additions & 0 deletions
@@ -236,6 +236,110 @@ def add_node(
             metadata=metadata,
         )
 
+    def add_nodes_batch(
+        self,
+        nodes: list[dict[str, Any]],
+        user_name: str | None = None,
+    ) -> None:
+        """
+        Batch add multiple memory nodes to the graph.
+
+        Args:
+            nodes: List of node dictionaries, each containing:
+                - id: str - Node ID
+                - memory: str - Memory content
+                - metadata: dict[str, Any] - Node metadata
+            user_name: Optional user name (will use config default if not provided)
+        """
+        if not nodes:
+            logger.warning("[add_nodes_batch] Empty nodes list, skipping")
+            return
+
+        logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes")
+
+        # user_name comes from parameter; fallback to config if missing
+        effective_user_name = user_name if user_name else self.config.user_name
+
+        # Prepare all nodes
+        prepared_nodes = []
+        for node_data in nodes:
+            try:
+                id = node_data["id"]
+                memory = node_data["memory"]
+                metadata = node_data.get("metadata", {})
+
+                logger.debug(f"[add_nodes_batch] Processing node id: {id}")
+
+                # Set user_name in metadata if needed
+                if not self.config.use_multi_db and (self.config.user_name or effective_user_name):
+                    metadata["user_name"] = effective_user_name
+
+                # Safely process metadata
+                metadata = _prepare_node_metadata(metadata)
+
+                # Flatten info fields to top level (for Neo4j flat structure)
+                metadata = _flatten_info_fields(metadata)
+
+                # Merge node and set metadata
+                created_at = metadata.pop("created_at")
+                updated_at = metadata.pop("updated_at")
+
+                # Serialization for sources
+                if metadata.get("sources"):
+                    for idx in range(len(metadata["sources"])):
+                        metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
+
+                prepared_nodes.append(
+                    {
+                        "id": id,
+                        "memory": memory,
+                        "created_at": created_at,
+                        "updated_at": updated_at,
+                        "metadata": metadata,
+                    }
+                )
+            except Exception as e:
+                logger.error(
+                    f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}",
+                    exc_info=True,
+                )
+                # Continue with other nodes
+                continue
+
+        if not prepared_nodes:
+            logger.warning("[add_nodes_batch] No valid nodes to insert after preparation")
+            return
+
+        # Batch insert using Neo4j UNWIND for better performance
+        query = """
+        UNWIND $nodes AS node
+        MERGE (n:Memory {id: node.id})
+        SET n.memory = node.memory,
+            n.created_at = datetime(node.created_at),
+            n.updated_at = datetime(node.updated_at),
+            n += node.metadata
+        """
+
+        # Prepare nodes data for UNWIND
+        nodes_data = [
+            {
+                "id": node["id"],
+                "memory": node["memory"],
+                "created_at": node["created_at"],
+                "updated_at": node["updated_at"],
+                "metadata": node["metadata"],
+            }
+            for node in prepared_nodes
+        ]
+
+        try:
+            with self.driver.session(database=self.db_name) as session:
+                session.run(query, nodes=nodes_data)
+                logger.info(f"[add_nodes_batch] Successfully inserted {len(prepared_nodes)} nodes")
+        except Exception as e:
+            logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True)
+            raise
+
     def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
         """
         Update node fields in Neo4j, auto-converting `created_at` and `updated_at` to datetime type if present.

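The performance win in the Neo4j implementation comes from UNWIND: the whole batch travels as one query parameter and the server iterates it, so there is one round-trip instead of one per node. A self-contained sketch of the same UNWIND + MERGE upsert pattern with the official neo4j Python driver, assuming placeholder connection details:

    from neo4j import GraphDatabase

    # Placeholder URI and credentials; the pattern, not the deployment, is the point.
    driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))

    query = """
    UNWIND $nodes AS node
    MERGE (n:Memory {id: node.id})
    SET n.memory = node.memory,
        n += node.metadata
    """

    nodes = [
        {"id": "a", "memory": "first", "metadata": {"user_name": "alice"}},
        {"id": "b", "memory": "second", "metadata": {"user_name": "alice"}},
    ]

    with driver.session(database="neo4j") as session:
        session.run(query, nodes=nodes)  # one round-trip for the whole batch
    driver.close()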
src/memos/graph_dbs/polardb.py

Lines changed: 246 additions & 0 deletions
@@ -3226,6 +3226,252 @@ def add_node(
         logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}")
         self._return_connection(conn)
 
+    @timed
+    def add_nodes_batch(
+        self,
+        nodes: list[dict[str, Any]],
+        user_name: str | None = None,
+    ) -> None:
+        """
+        Batch add multiple memory nodes to the graph.
+
+        Args:
+            nodes: List of node dictionaries, each containing:
+                - id: str - Node ID
+                - memory: str - Memory content
+                - metadata: dict[str, Any] - Node metadata
+            user_name: Optional user name (will use config default if not provided)
+        """
+        if not nodes:
+            logger.warning("[add_nodes_batch] Empty nodes list, skipping")
+            return
+
+        logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes")
+
+        # user_name comes from parameter; fallback to config if missing
+        effective_user_name = user_name if user_name else self.config.user_name
+
+        # Prepare all nodes
+        prepared_nodes = []
+        for node_data in nodes:
+            try:
+                id = node_data["id"]
+                memory = node_data["memory"]
+                metadata = node_data.get("metadata", {})
+
+                logger.debug(f"[add_nodes_batch] Processing node id: {id}")
+
+                # Set user_name in metadata
+                metadata["user_name"] = effective_user_name
+
+                metadata = _prepare_node_metadata(metadata)
+
+                # Merge node and set metadata
+                created_at = metadata.pop("created_at", datetime.utcnow().isoformat())
+                updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat())
+
+                # Prepare properties
+                properties = {
+                    "id": id,
+                    "memory": memory,
+                    "created_at": created_at,
+                    "updated_at": updated_at,
+                    **metadata,
+                }
+
+                # Generate embedding if not provided
+                if "embedding" not in properties or not properties["embedding"]:
+                    properties["embedding"] = generate_vector(
+                        self._get_config_value("embedding_dimension", 1024)
+                    )
+
+                # Serialization - JSON-serialize sources and usage fields
+                for field_name in ["sources", "usage"]:
+                    if properties.get(field_name):
+                        if isinstance(properties[field_name], list):
+                            for idx in range(len(properties[field_name])):
+                                # Serialize only when element is not a string
+                                if not isinstance(properties[field_name][idx], str):
+                                    properties[field_name][idx] = json.dumps(
+                                        properties[field_name][idx]
+                                    )
+                        elif isinstance(properties[field_name], str):
+                            # If already a string, leave as-is
+                            pass
+
+                # Extract embedding for separate column
+                embedding_vector = properties.pop("embedding", [])
+                if not isinstance(embedding_vector, list):
+                    embedding_vector = []
+
+                # Select column name based on embedding dimension
+                embedding_column = "embedding"  # default column
+                if len(embedding_vector) == 3072:
+                    embedding_column = "embedding_3072"
+                elif len(embedding_vector) == 1024:
+                    embedding_column = "embedding"
+                elif len(embedding_vector) == 768:
+                    embedding_column = "embedding_768"
+
+                prepared_nodes.append(
+                    {
+                        "id": id,
+                        "memory": memory,
+                        "properties": properties,
+                        "embedding_vector": embedding_vector,
+                        "embedding_column": embedding_column,
+                    }
+                )
+            except Exception as e:
+                logger.error(
+                    f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}",
+                    exc_info=True,
+                )
+                # Continue with other nodes
+                continue
+
+        if not prepared_nodes:
+            logger.warning("[add_nodes_batch] No valid nodes to insert after preparation")
+            return
+
+        # Group nodes by embedding column to optimize batch inserts
+        nodes_by_embedding_column = {}
+        for node in prepared_nodes:
+            col = node["embedding_column"]
+            if col not in nodes_by_embedding_column:
+                nodes_by_embedding_column[col] = []
+            nodes_by_embedding_column[col].append(node)
+
+        conn = None
+        try:
+            conn = self._get_connection()
+            with conn.cursor() as cursor:
+                # Process each group separately
+                for embedding_column, nodes_group in nodes_by_embedding_column.items():
+                    # Batch delete existing records using IN clause
+                    ids_to_delete = [node["id"] for node in nodes_group]
+                    if ids_to_delete:
+                        delete_query = f"""
+                            DELETE FROM {self.db_name}_graph."Memory"
+                            WHERE id IN (
+                                SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, unnest(%s::text[])::cstring)
+                            )
+                        """
+                        cursor.execute(delete_query, (ids_to_delete,))
+
+                    # Batch get graph_ids for all nodes
+                    get_graph_ids_query = f"""
+                        SELECT
+                            id_val,
+                            ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, id_val::text::cstring) as graph_id
+                        FROM unnest(%s::text[]) as id_val
+                    """
+                    cursor.execute(get_graph_ids_query, (ids_to_delete,))
+                    graph_id_map = {row[0]: row[1] for row in cursor.fetchall()}
+
+                    # Add graph_id to properties
+                    for node in nodes_group:
+                        graph_id = graph_id_map.get(node["id"])
+                        if graph_id:
+                            node["properties"]["graph_id"] = str(graph_id)
+
+                    # Batch insert using VALUES with multiple rows
+                    # Use psycopg2.extras.execute_values for efficient batch insert
+                    from psycopg2.extras import execute_values
+
+                    if embedding_column and any(node["embedding_vector"] for node in nodes_group):
+                        # Prepare data tuples for batch insert with embedding
+                        data_tuples = []
+                        for node in nodes_group:
+                            # Each tuple: (id, properties_json, embedding_json)
+                            data_tuples.append(
+                                (
+                                    node["id"],
+                                    json.dumps(node["properties"]),
+                                    json.dumps(node["embedding_vector"])
+                                    if node["embedding_vector"]
+                                    else None,
+                                )
+                            )
+
+                        # Build the INSERT query template
+                        insert_query = f"""
+                            INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column})
+                            VALUES %s
+                        """
+
+                        # Build the VALUES template for execute_values
+                        # Each row: (graph_id_function, agtype, vector)
+                        # Note: properties column is agtype, not jsonb
+                        template = f"""
+                            (
+                                ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
+                                %s::text::agtype,
+                                %s::vector
+                            )
+                        """
+                        logger.info(
+                            f"[add_nodes_batch] embedding_column Inserting insert_query:{insert_query}"
+                        )
+                        logger.info(
+                            f"[add_nodes_batch] embedding_column Inserting data_tuples:{data_tuples}"
+                        )
+
+                        # Execute batch insert
+                        execute_values(
+                            cursor,
+                            insert_query,
+                            data_tuples,
+                            template=template,
+                            page_size=100,  # Insert in batches of 100
+                        )
+                    else:
+                        # Prepare data tuples for batch insert without embedding
+                        data_tuples = []
+                        for node in nodes_group:
+                            # Each tuple: (id, properties_json)
+                            data_tuples.append(
+                                (
+                                    node["id"],
+                                    json.dumps(node["properties"]),
+                                )
+                            )
+
+                        # Build the INSERT query template
+                        insert_query = f"""
+                            INSERT INTO {self.db_name}_graph."Memory"(id, properties)
+                            VALUES %s
+                        """
+
+                        # Build the VALUES template for execute_values
+                        # Note: properties column is agtype, not jsonb
+                        template = f"""
+                            (
+                                ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
+                                %s::text::agtype
+                            )
+                        """
+                        logger.info(f"[add_nodes_batch] Inserting insert_query:{insert_query}")
+                        logger.info(f"[add_nodes_batch] Inserting data_tuples:{data_tuples}")
+                        # Execute batch insert
+                        execute_values(
+                            cursor,
+                            insert_query,
+                            data_tuples,
+                            template=template,
+                            page_size=100,  # Insert in batches of 100
+                        )
+
+                    logger.info(
+                        f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}"
+                    )
+
+        except Exception as e:
+            logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True)
+            raise
+        finally:
+            self._return_connection(conn)
+
     def _build_node_from_agtype(self, node_agtype, embedding=None):
         """
         Parse the cypher-returned column `n` (agtype or JSON string)

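The PolarDB implementation batches at two levels: nodes are grouped by target embedding column, and each group is written with psycopg2's execute_values, which expands a single VALUES clause into pages of rows. The per-row template is what allows the agtype and vector casts above. A reduced sketch of that mechanism against a plain table (connection string, table, and columns are made up for illustration):

    import json

    import psycopg2
    from psycopg2.extras import execute_values

    # Placeholder DSN and table; demonstrates the template + page_size mechanics only.
    conn = psycopg2.connect("dbname=test user=postgres")
    rows = [("a", json.dumps({"k": 1})), ("b", json.dumps({"k": 2}))]

    with conn.cursor() as cur:
        execute_values(
            cur,
            "INSERT INTO demo_memory (id, properties) VALUES %s",
            rows,
            template="(%s, %s::jsonb)",  # per-row casts, analogous to ::agtype / ::vector above
            page_size=100,               # rows per generated INSERT statement
        )
    conn.commit()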
src/memos/mem_scheduler/base_scheduler.py

Lines changed: 2 additions & 1 deletion
@@ -224,7 +224,8 @@ def initialize_modules(
         if self.dispatcher:
             self.dispatcher.status_tracker = self.status_tracker
         if self.memos_message_queue:
-            self.memos_message_queue.status_tracker = self.status_tracker
+            # Use the setter to propagate to the inner queue (e.g. SchedulerRedisQueue)
+            self.memos_message_queue.set_status_tracker(self.status_tracker)
         # initialize submodules
         self.chat_llm = chat_llm
         self.process_llm = process_llm

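The scheduler change replaces a bare attribute assignment with a setter call so that wrapper queues can forward the tracker to the queue they delegate to. The commit does not show set_status_tracker itself; a hypothetical sketch of what such a setter plausibly does:

    class QueueWrapper:
        """Hypothetical wrapper delegating to an inner queue (e.g. SchedulerRedisQueue)."""

        def __init__(self, inner_queue=None):
            self.status_tracker = None
            self._inner_queue = inner_queue

        def set_status_tracker(self, status_tracker):
            # Keep the local reference and propagate it to the delegate;
            # a plain attribute write on the wrapper would leave the inner
            # queue's tracker unset.
            self.status_tracker = status_tracker
            if self._inner_queue is not None:
                self._inner_queue.status_tracker = status_tracker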