Skip to content

Commit ee6fa28

Browse files
committed
Merge branch 'feat/complete_multi_modal' of github.com:CaralHsi/MemOSRealPublic into feat/complete_multi_modal
2 parents d22f329 + c155e63 commit ee6fa28

File tree

13 files changed

+335
-85
lines changed

13 files changed

+335
-85
lines changed

src/memos/api/product_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ class APIADDRequest(BaseRequest):
469469
),
470470
)
471471

472-
info: dict[str, str] | None = Field(
472+
info: dict[str, Any] | None = Field(
473473
None,
474474
description=(
475475
"Additional metadata for the add request. "

src/memos/graph_dbs/neo4j.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,17 +1441,24 @@ def build_filter_condition(condition_dict: dict, param_counter: list) -> tuple[s
14411441
f"{node_alias}.{key} {cypher_op} ${param_name}"
14421442
)
14431443
elif op == "contains":
1444-
# Handle contains operator (for array fields like tags, sources)
1445-
param_name = f"filter_{key}_{op}_{param_counter[0]}"
1446-
param_counter[0] += 1
1447-
params[param_name] = op_value
1448-
1449-
# For array fields, check if element is in array
1450-
if key in ("tags", "sources"):
1451-
condition_parts.append(f"${param_name} IN {node_alias}.{key}")
1452-
else:
1453-
# For non-array fields, contains might not be applicable, but we'll treat it as IN for consistency
1454-
condition_parts.append(f"${param_name} IN {node_alias}.{key}")
1444+
# Handle contains operator (for array fields)
1445+
# Only supports array format: {"field": {"contains": ["value1", "value2"]}}
1446+
# Single string values are not supported; use the array format instead: {"field": {"contains": ["value"]}}
1447+
if not isinstance(op_value, list):
1448+
raise ValueError(
1449+
f"contains operator only supports array format. "
1450+
f"Use {{'{key}': {{'contains': ['{op_value}']}}}} instead of {{'{key}': {{'contains': '{op_value}'}}}}"
1451+
)
1452+
# Handle array of values: generate AND conditions for each value (all must be present)
1453+
and_conditions = []
1454+
for item in op_value:
1455+
param_name = f"filter_{key}_{op}_{param_counter[0]}"
1456+
param_counter[0] += 1
1457+
params[param_name] = item
1458+
# For array fields, check if element is in array
1459+
and_conditions.append(f"${param_name} IN {node_alias}.{key}")
1460+
if and_conditions:
1461+
condition_parts.append(f"({' AND '.join(and_conditions)})")
14551462
elif op == "like":
14561463
# Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%')
14571464
# Neo4j uses CONTAINS for string matching

src/memos/graph_dbs/polardb.py

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3443,23 +3443,40 @@ def build_cypher_filter_condition(condition_dict: dict) -> str:
34433443
condition_parts.append(f"n.{key} = {op_value}")
34443444
elif op == "contains":
34453445
# Handle contains operator (for array fields)
3446+
# Only supports array format: {"field": {"contains": ["value1", "value2"]}}
3447+
# Single string values are not supported; use the array format instead: {"field": {"contains": ["value"]}}
3448+
if not isinstance(op_value, list):
3449+
raise ValueError(
3450+
f"contains operator only supports array format. "
3451+
f"Use {{'{key}': {{'contains': ['{op_value}']}}}} instead of {{'{key}': {{'contains': '{op_value}'}}}}"
3452+
)
34463453
# Check if key starts with "info." prefix
34473454
if key.startswith("info."):
34483455
info_field = key[5:] # Remove "info." prefix
3449-
if isinstance(op_value, str):
3450-
escaped_value = escape_cypher_string(op_value)
3451-
condition_parts.append(
3452-
f"'{escaped_value}' IN n.info.{info_field}"
3453-
)
3454-
else:
3455-
condition_parts.append(f"{op_value} IN n.info.{info_field}")
3456+
# Handle array of values: generate AND conditions for each value (all must be present)
3457+
and_conditions = []
3458+
for item in op_value:
3459+
if isinstance(item, str):
3460+
escaped_value = escape_cypher_string(item)
3461+
and_conditions.append(
3462+
f"'{escaped_value}' IN n.info.{info_field}"
3463+
)
3464+
else:
3465+
and_conditions.append(f"{item} IN n.info.{info_field}")
3466+
if and_conditions:
3467+
condition_parts.append(f"({' AND '.join(and_conditions)})")
34563468
else:
34573469
# Direct property access
3458-
if isinstance(op_value, str):
3459-
escaped_value = escape_cypher_string(op_value)
3460-
condition_parts.append(f"'{escaped_value}' IN n.{key}")
3461-
else:
3462-
condition_parts.append(f"{op_value} IN n.{key}")
3470+
# Handle array of values: generate AND conditions for each value (all must be present)
3471+
and_conditions = []
3472+
for item in op_value:
3473+
if isinstance(item, str):
3474+
escaped_value = escape_cypher_string(item)
3475+
and_conditions.append(f"'{escaped_value}' IN n.{key}")
3476+
else:
3477+
and_conditions.append(f"{item} IN n.{key}")
3478+
if and_conditions:
3479+
condition_parts.append(f"({' AND '.join(and_conditions)})")
34633480
elif op == "like":
34643481
# Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%')
34653482
# Check if key starts with "info." prefix
@@ -3668,29 +3685,46 @@ def build_filter_condition(condition_dict: dict) -> str:
36683685
)
36693686
elif op == "contains":
36703687
# Handle contains operator (for array fields) - use @> operator
3688+
# Only supports array format: {"field": {"contains": ["value1", "value2"]}}
3689+
# Single string values are not supported; use the array format instead: {"field": {"contains": ["value"]}}
3690+
if not isinstance(op_value, list):
3691+
raise ValueError(
3692+
f"contains operator only supports array format. "
3693+
f"Use {{'{key}': {{'contains': ['{op_value}']}}}} instead of {{'{key}': {{'contains': '{op_value}'}}}}"
3694+
)
36713695
# Check if key starts with "info." prefix
36723696
if key.startswith("info."):
36733697
info_field = key[5:] # Remove "info." prefix
3674-
if isinstance(op_value, str):
3675-
escaped_value = escape_sql_string(op_value)
3676-
condition_parts.append(
3677-
f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '\"{escaped_value}\"'::agtype"
3678-
)
3679-
else:
3680-
condition_parts.append(
3681-
f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> {op_value}::agtype"
3682-
)
3698+
# Handle array of values: generate AND conditions for each value (all must be present)
3699+
and_conditions = []
3700+
for item in op_value:
3701+
if isinstance(item, str):
3702+
escaped_value = escape_sql_string(item)
3703+
and_conditions.append(
3704+
f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '\"{escaped_value}\"'::agtype"
3705+
)
3706+
else:
3707+
and_conditions.append(
3708+
f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> {item}::agtype"
3709+
)
3710+
if and_conditions:
3711+
condition_parts.append(f"({' AND '.join(and_conditions)})")
36833712
else:
36843713
# Direct property access
3685-
if isinstance(op_value, str):
3686-
escaped_value = escape_sql_string(op_value)
3687-
condition_parts.append(
3688-
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '\"{escaped_value}\"'::agtype"
3689-
)
3690-
else:
3691-
condition_parts.append(
3692-
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> {op_value}::agtype"
3693-
)
3714+
# Handle array of values: generate AND conditions for each value (all must be present)
3715+
and_conditions = []
3716+
for item in op_value:
3717+
if isinstance(item, str):
3718+
escaped_value = escape_sql_string(item)
3719+
and_conditions.append(
3720+
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '\"{escaped_value}\"'::agtype"
3721+
)
3722+
else:
3723+
and_conditions.append(
3724+
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> {item}::agtype"
3725+
)
3726+
if and_conditions:
3727+
condition_parts.append(f"({' AND '.join(and_conditions)})")
36943728
elif op == "like":
36953729
# Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%')
36963730
# Check if key starts with "info." prefix

src/memos/mem_scheduler/optimized_scheduler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def mix_search_memories(
138138
target_session_id = search_req.session_id
139139
if not target_session_id:
140140
target_session_id = "default_session"
141-
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
141+
search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
142+
search_filter = search_req.filter
142143

143144
# Rerank Memories - reranker expects TextualMemoryItem objects
144145

@@ -155,6 +156,7 @@ def mix_search_memories(
155156
mode=SearchMode.FAST,
156157
manual_close_internet=not search_req.internet_search,
157158
search_filter=search_filter,
159+
search_priority=search_priority,
158160
info=info,
159161
)
160162

@@ -178,7 +180,7 @@ def mix_search_memories(
178180
query=search_req.query, # Use search_req.query instead of undefined query
179181
graph_results=history_memories, # Pass TextualMemoryItem objects directly
180182
top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k
181-
search_filter=search_filter,
183+
search_priority=search_priority,
182184
)
183185
logger.info(f"Reranked {len(sorted_history_memories)} history memories.")
184186
processed_hist_mem = self.searcher.post_retrieve(
@@ -234,6 +236,7 @@ def mix_search_memories(
234236
mode=SearchMode.FAST,
235237
memory_type="All",
236238
search_filter=search_filter,
239+
search_priority=search_priority,
237240
info=info,
238241
)
239242
else:

src/memos/memories/textual/prefer_text_memory/retrievers.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=No
1717

1818
@abstractmethod
1919
def retrieve(
20-
self, query: str, top_k: int, info: dict[str, Any] | None = None
20+
self,
21+
query: str,
22+
top_k: int,
23+
info: dict[str, Any] | None = None,
24+
search_filter: dict[str, Any] | None = None,
2125
) -> list[TextualMemoryItem]:
2226
"""Retrieve memories from the retriever."""
2327

@@ -76,14 +80,19 @@ def _original_text_reranker(
7680
return prefs_mem
7781

7882
def retrieve(
79-
self, query: str, top_k: int, info: dict[str, Any] | None = None
83+
self,
84+
query: str,
85+
top_k: int,
86+
info: dict[str, Any] | None = None,
87+
search_filter: dict[str, Any] | None = None,
8088
) -> list[TextualMemoryItem]:
8189
"""Retrieve memories from the naive retriever."""
8290
# TODO: query rewriting and session filtering are not supported yet
8391
if info:
8492
info = info.copy() # Create a copy to avoid modifying the original
8593
info.pop("chat_history", None)
8694
info.pop("session_id", None)
95+
search_filter = {"and": [info, search_filter]}
8796
query_embeddings = self.embedder.embed([query]) # Pass as list to get list of embeddings
8897
query_embedding = query_embeddings[0] # Get the first (and only) embedding
8998

@@ -96,15 +105,15 @@ def retrieve(
96105
query,
97106
"explicit_preference",
98107
top_k * 2,
99-
info,
108+
search_filter,
100109
)
101110
future_implicit = executor.submit(
102111
self.vector_db.search,
103112
query_embedding,
104113
query,
105114
"implicit_preference",
106115
top_k * 2,
107-
info,
116+
search_filter,
108117
)
109118

110119
# Wait for all results

src/memos/memories/textual/preference.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def get_memory(
7676
"""
7777
return self.extractor.extract(messages, type, info)
7878

79-
def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]:
79+
def search(
80+
self, query: str, top_k: int, info=None, search_filter=None, **kwargs
81+
) -> list[TextualMemoryItem]:
8082
"""Search for memories based on a query.
8183
Args:
8284
query (str): The query to search for.
@@ -85,7 +87,8 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
8587
Returns:
8688
list[TextualMemoryItem]: List of matching memories.
8789
"""
88-
return self.retriever.retrieve(query, top_k, info)
90+
logger.info(f"search_filter for preference memory: {search_filter}")
91+
return self.retriever.retrieve(query, top_k, info, search_filter)
8992

9093
def load(self, dir: str) -> None:
9194
"""Load memories from the specified directory.

src/memos/memories/textual/simple_preference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def get_memory(
5050
"""
5151
return self.extractor.extract(messages, type, info)
5252

53-
def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]:
53+
def search(
54+
self, query: str, top_k: int, info=None, search_filter=None, **kwargs
55+
) -> list[TextualMemoryItem]:
5456
"""Search for memories based on a query.
5557
Args:
5658
query (str): The query to search for.
@@ -59,7 +61,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
5961
Returns:
6062
list[TextualMemoryItem]: List of matching memories.
6163
"""
62-
return self.retriever.retrieve(query, top_k, info)
64+
return self.retriever.retrieve(query, top_k, info, search_filter)
6365

6466
def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]:
6567
"""Add memories.

src/memos/memories/textual/tree.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def search(
162162
mode: str = "fast",
163163
memory_type: str = "All",
164164
manual_close_internet: bool = True,
165+
search_priority: dict | None = None,
165166
search_filter: dict | None = None,
166167
user_name: str | None = None,
167168
) -> list[TextualMemoryItem]:
@@ -209,7 +210,14 @@ def search(
209210
manual_close_internet=manual_close_internet,
210211
)
211212
return searcher.search(
212-
query, top_k, info, mode, memory_type, search_filter, user_name=user_name
213+
query,
214+
top_k,
215+
info,
216+
mode,
217+
memory_type,
218+
search_filter,
219+
search_priority,
220+
user_name=user_name,
213221
)
214222

215223
def get_relevant_subgraph(

0 commit comments

Comments
 (0)