Skip to content

Commit dcd3d50

Browse files
author
yuan.wang
committed
Merge branch 'dev' into feat/fix_palyground_bug
2 parents e638039 + 8b5f796 commit dcd3d50

File tree

25 files changed

+429
-283
lines changed

25 files changed

+429
-283
lines changed

examples/mem_scheduler/memos_w_scheduler.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
from memos.mem_cube.general import GeneralMemCube
1414
from memos.mem_os.main import MOS
1515
from memos.mem_scheduler.general_scheduler import GeneralScheduler
16-
from memos.mem_scheduler.schemas.general_schemas import (
17-
ADD_LABEL,
18-
ANSWER_LABEL,
19-
MEM_ARCHIVE_LABEL,
20-
MEM_ORGANIZE_LABEL,
21-
MEM_UPDATE_LABEL,
22-
QUERY_LABEL,
23-
)
2416
from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem
17+
from memos.mem_scheduler.schemas.task_schemas import (
18+
ADD_TASK_LABEL,
19+
ANSWER_TASK_LABEL,
20+
MEM_ARCHIVE_TASK_LABEL,
21+
MEM_ORGANIZE_TASK_LABEL,
22+
MEM_UPDATE_TASK_LABEL,
23+
QUERY_TASK_LABEL,
24+
)
2525
from memos.mem_scheduler.utils.filter_utils import transform_name_to_key
2626

2727

@@ -118,24 +118,24 @@ def _first_content() -> str:
118118
return memcube_content[0].get("content", "") or content
119119
return content
120120

121-
if label in ("addMessage", QUERY_LABEL, ANSWER_LABEL):
121+
if label in ("addMessage", QUERY_TASK_LABEL, ANSWER_TASK_LABEL):
122122
target_cube = cube_display.replace("MemCube", "")
123123
title = _format_title(item.timestamp, f"addMessages to {target_cube} MemCube")
124124
return title, _truncate_with_rules(_first_content())
125125

126-
if label in ("addMemory", ADD_LABEL):
126+
if label in ("addMemory", ADD_TASK_LABEL):
127127
title = _format_title(item.timestamp, f"{cube_display} added {memory_len} memories")
128128
return title, _truncate_with_rules(_first_content())
129129

130-
if label in ("updateMemory", MEM_UPDATE_LABEL):
130+
if label in ("updateMemory", MEM_UPDATE_TASK_LABEL):
131131
title = _format_title(item.timestamp, f"{cube_display} updated {memory_len} memories")
132132
return title, _truncate_with_rules(_first_content())
133133

134-
if label in ("archiveMemory", MEM_ARCHIVE_LABEL):
134+
if label in ("archiveMemory", MEM_ARCHIVE_TASK_LABEL):
135135
title = _format_title(item.timestamp, f"{cube_display} archived {memory_len} memories")
136136
return title, _truncate_with_rules(_first_content())
137137

138-
if label in ("mergeMemory", MEM_ORGANIZE_LABEL):
138+
if label in ("mergeMemory", MEM_ORGANIZE_TASK_LABEL):
139139
title = _format_title(item.timestamp, f"{cube_display} merged {memory_len} memories")
140140
merged = [c for c in memcube_content if c.get("type") == "merged"]
141141
post = [c for c in memcube_content if c.get("type") == "postMerge"]

examples/mem_scheduler/redis_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from memos.configs.mem_scheduler import SchedulerConfigFactory
1010
from memos.mem_cube.general import GeneralMemCube
1111
from memos.mem_scheduler.scheduler_factory import SchedulerFactory
12-
from memos.mem_scheduler.schemas.general_schemas import QUERY_LABEL
1312
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
13+
from memos.mem_scheduler.schemas.task_schemas import QUERY_TASK_LABEL
1414

1515

1616
if TYPE_CHECKING:
@@ -55,7 +55,7 @@ def service_run():
5555
message_item = ScheduleMessageItem(
5656
user_id=user_id,
5757
mem_cube_id="mem_cube_2",
58-
label=QUERY_LABEL,
58+
label=QUERY_TASK_LABEL,
5959
mem_cube=mem_cube,
6060
content=query,
6161
timestamp=datetime.now(),

examples/mem_scheduler/try_schedule_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from memos.mem_cube.general import GeneralMemCube
1515
from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler
1616
from memos.mem_scheduler.general_scheduler import GeneralScheduler
17-
from memos.mem_scheduler.schemas.general_schemas import (
17+
from memos.mem_scheduler.schemas.task_schemas import (
1818
NOT_APPLICABLE_TYPE,
1919
)
2020

src/memos/api/handlers/chat_handler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030
prepare_reference_data,
3131
process_streaming_references_complete,
3232
)
33-
from memos.mem_scheduler.schemas.general_schemas import (
34-
ANSWER_LABEL,
35-
QUERY_LABEL,
36-
)
3733
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
34+
from memos.mem_scheduler.schemas.task_schemas import (
35+
ANSWER_TASK_LABEL,
36+
QUERY_TASK_LABEL,
37+
)
3838
from memos.templates.mos_prompts import (
3939
FURTHER_SUGGESTION_PROMPT,
4040
get_memos_prompt,
@@ -244,7 +244,7 @@ def generate_chat_response() -> Generator[str, None, None]:
244244
user_id=chat_req.user_id,
245245
mem_cube_id=scheduler_cube_id,
246246
query=chat_req.query,
247-
label=QUERY_LABEL,
247+
label=QUERY_TASK_LABEL,
248248
)
249249
# Extract memories from search results
250250
memories_list = []
@@ -406,7 +406,7 @@ def generate_chat_response() -> Generator[str, None, None]:
406406
user_id=chat_req.user_id,
407407
mem_cube_id=scheduler_cube_id,
408408
query=chat_req.query,
409-
label=QUERY_LABEL,
409+
label=QUERY_TASK_LABEL,
410410
)
411411

412412
# ====== first search without parse goal ======
@@ -1091,7 +1091,7 @@ async def _post_chat_processing(
10911091

10921092
# Send answer to scheduler
10931093
self._send_message_to_scheduler(
1094-
user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_LABEL
1094+
user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_TASK_LABEL
10951095
)
10961096

10971097
self.logger.info(f"Post-chat processing completed for user {user_id}")

src/memos/configs/mem_reader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from datetime import datetime
22
from typing import Any, ClassVar
33

4-
from pydantic import Field, field_validator, model_validator
4+
from pydantic import ConfigDict, Field, field_validator, model_validator
55

66
from memos.configs.base import BaseConfig
77
from memos.configs.chunker import ChunkerConfigFactory
@@ -44,6 +44,9 @@ def parse_datetime(cls, value):
4444
class SimpleStructMemReaderConfig(BaseMemReaderConfig):
4545
"""SimpleStruct MemReader configuration class."""
4646

47+
# Allow passing additional fields without raising validation errors
48+
model_config = ConfigDict(extra="allow", strict=True)
49+
4750

4851
class MultiModalStructMemReaderConfig(BaseMemReaderConfig):
4952
"""MultiModalStruct MemReader configuration class."""

src/memos/graph_dbs/neo4j.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1588,7 +1588,7 @@ def delete_node_by_prams(
15881588
file_id_and_conditions.append(f"${param_name} IN n.file_ids")
15891589
if file_id_and_conditions:
15901590
# Use OR so that a node qualifies when ANY of the requested file_ids is present
1591-
where_clauses.append(f"({' AND '.join(file_id_and_conditions)})")
1591+
where_clauses.append(f"({' OR '.join(file_id_and_conditions)})")
15921592

15931593
# Query nodes by filter if provided
15941594
filter_ids = []

src/memos/graph_dbs/neo4j_community.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,136 @@ def build_filter_condition(
706706
result = session.run(query, params)
707707
return [record["id"] for record in result]
708708

709+
def delete_node_by_prams(
    self,
    writable_cube_ids: list[str],
    memory_ids: list[str] | None = None,
    file_ids: list[str] | None = None,
    filter: dict | None = None,
) -> int:
    """
    Delete Memory nodes selected by memory_ids, file_ids, or a metadata filter.

    The three selectors are combined with OR (a node matching any one of them
    is deleted), and the result is restricted to nodes owned by one of the
    writable cubes (``n.user_name``).

    Args:
        writable_cube_ids (list[str]): Cube IDs (user_name) the caller may
            delete from. Required and must be non-empty.
        memory_ids (list[str], optional): Memory node IDs (``n.id``) to delete.
        file_ids (list[str], optional): A node qualifies when ANY of these IDs
            appears in its ``file_ids`` array property — consistent with the
            neo4j and polardb implementations of this method.
        filter (dict, optional): Metadata filter resolved to concrete node IDs
            via ``get_by_metadata``.

    Returns:
        int: Number of nodes deleted.

    Raises:
        ValueError: If ``writable_cube_ids`` is empty.
    """
    logger.info(
        f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, "
        f"filter: {filter}, writable_cube_ids: {writable_cube_ids}"
    )

    # Ownership scope is mandatory: refuse to build an unscoped delete.
    if not writable_cube_ids:
        raise ValueError("writable_cube_ids is required and cannot be empty")

    where_clauses = []
    params = {}

    # Ownership condition: user_name may match ANY writable cube (OR).
    user_name_conditions = []
    for idx, cube_id in enumerate(writable_cube_ids):
        param_name = f"cube_id_{idx}"
        user_name_conditions.append(f"n.user_name = ${param_name}")
        params[param_name] = cube_id

    # memory_ids select nodes by their primary id.
    if memory_ids:
        where_clauses.append("n.id IN $memory_ids")
        params["memory_ids"] = memory_ids

    # file_ids: a node qualifies when ANY requested file_id is present in its
    # file_ids array (OR — kept consistent with neo4j.py and polardb.py).
    if file_ids:
        file_id_conditions = []
        for idx, file_id in enumerate(file_ids):
            param_name = f"file_id_{idx}"
            params[param_name] = file_id
            file_id_conditions.append(f"${param_name} IN n.file_ids")
        where_clauses.append(f"({' OR '.join(file_id_conditions)})")

    # Resolve the metadata filter to concrete node ids, scoped to the cubes.
    filter_ids = []
    if filter:
        filter_ids = self.get_by_metadata(
            filters=[],
            user_name=None,
            filter=filter,
            knowledgebase_ids=writable_cube_ids,
        )
    if filter_ids:
        where_clauses.append("n.id IN $filter_ids")
        params["filter_ids"] = filter_ids

    # Nothing selected (and a filter that matched nothing falls through here
    # too, preserving the original return-0 behavior).
    if not where_clauses:
        logger.warning(
            "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
        )
        return 0

    # Selectors combine with OR; ownership combines with AND.
    data_conditions = " OR ".join(f"({clause})" for clause in where_clauses)
    user_name_where = " OR ".join(user_name_conditions)
    ids_where = f"({user_name_where}) AND ({data_conditions})"

    # Delete and count in a single statement: Cypher allows RETURN after
    # DETACH DELETE, which avoids the count-then-delete race of two queries.
    delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n RETURN count(n) AS node_count"
    logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")

    try:
        with self.driver.session(database=self.db_name) as session:
            record = session.run(delete_query, **params).single()
            deleted_count = (record["node_count"] or 0) if record else 0
    except Exception as e:
        logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
        raise

    logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
    return deleted_count
709839
def clear(self, user_name: str | None = None) -> None:
710840
"""
711841
Clear the entire graph if the target database exists.

src/memos/graph_dbs/polardb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4113,6 +4113,7 @@ def parse_filter(
41134113
"memory_type",
41144114
"node_type",
41154115
"info",
4116+
"source",
41164117
}
41174118

41184119
def process_condition(condition):
@@ -4216,7 +4217,7 @@ def delete_node_by_prams(
42164217
file_id_and_conditions.append(f"'{escaped_id}' IN n.file_ids")
42174218
if file_id_and_conditions:
42184219
# Use OR so that a node qualifies when ANY of the requested file_ids is present
4219-
where_conditions.append(f"({' AND '.join(file_id_and_conditions)})")
4220+
where_conditions.append(f"({' OR '.join(file_id_and_conditions)})")
42204221

42214222
# Query nodes by filter if provided
42224223
filter_ids = set()

0 commit comments

Comments
 (0)