Commit 58fd663

Merge remote-tracking branch 'upstream/dev' into dev

2 parents: 7c4db5c + 666698d

12 files changed (+385 / -256 lines)

src/memos/api/handlers/component_init.py

Lines changed: 2 additions & 2 deletions
@@ -247,7 +247,7 @@ def init_server() -> dict[str, Any]:
             config_factory=pref_retriever_config,
             llm_provider=llm,
             embedder=embedder,
-            reranker=reranker,
+            reranker=feedback_reranker,
             vector_db=vector_db,
         )
         if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true"
@@ -262,7 +262,7 @@ def init_server() -> dict[str, Any]:
             extractor_llm=llm,
             vector_db=vector_db,
             embedder=embedder,
-            reranker=reranker,
+            reranker=feedback_reranker,
             extractor=pref_extractor,
             adder=pref_adder,
             retriever=pref_retriever,

src/memos/graph_dbs/polardb.py

Lines changed: 38 additions & 43 deletions
@@ -4792,35 +4792,35 @@ def delete_node_by_prams(
         # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
         user_name_conditions = []
         for cube_id in writable_cube_ids:
-            # Escape single quotes in cube IDs
-            escaped_cube_id = str(cube_id).replace("'", "\\'")
-            user_name_conditions.append(f"n.user_name = '{escaped_cube_id}'")
+            # Use agtype_access_operator with VARIADIC ARRAY format for consistency
+            user_name_conditions.append(
+                f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
+            )

         # Build WHERE conditions separately for memory_ids and file_ids
         where_conditions = []

-        # Handle memory_ids: query n.id
+        # Handle memory_ids: query properties.id
         if memory_ids and len(memory_ids) > 0:
             memory_id_conditions = []
             for node_id in memory_ids:
-                # Escape single quotes in node IDs
-                escaped_id = str(node_id).replace("'", "\\'")
-                memory_id_conditions.append(f"'{escaped_id}'")
+                memory_id_conditions.append(
+                    f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
+                )
             if memory_id_conditions:
-                where_conditions.append(f"n.id IN [{', '.join(memory_id_conditions)}]")
+                where_conditions.append(f"({' OR '.join(memory_id_conditions)})")

-        # Handle file_ids: query n.file_ids field
-        # All file_ids must be present in the array field (AND relationship)
+        # Check if any file_id is in the file_ids array field (OR relationship)
         if file_ids and len(file_ids) > 0:
-            file_id_and_conditions = []
+            file_id_conditions = []
             for file_id in file_ids:
-                # Escape single quotes in file IDs
-                escaped_id = str(file_id).replace("'", "\\'")
-                # Check if this file_id is in the file_ids array field
-                file_id_and_conditions.append(f"'{escaped_id}' IN n.file_ids")
-            if file_id_and_conditions:
-                # Use AND to require all file_ids to be present
-                where_conditions.append(f"({' OR '.join(file_id_and_conditions)})")
+                # Format: agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '"file_ids"'::agtype]), '"file_id"'::agtype)
+                file_id_conditions.append(
+                    f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
+                )
+            if file_id_conditions:
+                # Use OR to match any file_id in the array
+                where_conditions.append(f"({' OR '.join(file_id_conditions)})")

         # Query nodes by filter if provided
         filter_ids = set()
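
For orientation, here is a minimal sketch of what the new condition builder emits, using hypothetical inputs (writable_cube_ids=["cube_1"], file_ids=["f1", "f2"]); the driver code and IDs below are illustrative, not part of the commit:

# Hypothetical values, for illustration only
writable_cube_ids = ["cube_1"]
file_ids = ["f1", "f2"]

# Any writable cube id may match the vertex's user_name property (OR)
user_name_conditions = [
    f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
    for cube_id in writable_cube_ids
]

# A vertex matches if its file_ids array contains any requested id (OR)
file_id_conditions = [
    f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
    for file_id in file_ids
]

where_clause = f"({' OR '.join(user_name_conditions)}) AND ({' OR '.join(file_id_conditions)})"
print(where_clause)

The values inside the agtype literals are double-quoted ('"cube_1"'::agtype) because agtype values are JSON-like, so the comparison is made against the JSON string form of the property rather than a bare SQL string.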
@@ -4846,11 +4846,11 @@ def delete_node_by_prams(
         if filter_ids:
             filter_id_conditions = []
             for node_id in filter_ids:
-                # Escape single quotes in node IDs
-                escaped_id = str(node_id).replace("'", "\\'")
-                filter_id_conditions.append(f"'{escaped_id}'")
+                filter_id_conditions.append(
+                    f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
+                )
             if filter_id_conditions:
-                where_conditions.append(f"n.id IN [{', '.join(filter_id_conditions)}]")
+                where_conditions.append(f"({' OR '.join(filter_id_conditions)})")

         # If no conditions (except user_name), return 0
         if not where_conditions:
@@ -4865,26 +4865,21 @@ def delete_node_by_prams(

         # Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
         user_name_where = " OR ".join(user_name_conditions)
-        ids_where = f"{user_name_where} AND ({data_conditions})"
+        where_clause = f"({user_name_where}) AND ({data_conditions})"

-        # Use Cypher DELETE query
+        # Use SQL DELETE query for better performance
         # First count matching nodes to get accurate count
         count_query = f"""
-            SELECT * FROM cypher('{self.db_name}_graph', $$
-                MATCH (n:Memory)
-                WHERE {ids_where}
-                RETURN count(n) AS node_count
-            $$) AS (node_count agtype)
+            SELECT COUNT(*)
+            FROM "{self.db_name}_graph"."Memory"
+            WHERE {where_clause}
         """
         logger.info(f"[delete_node_by_prams] count_query: {count_query}")

         # Then delete nodes
         delete_query = f"""
-            SELECT * FROM cypher('{self.db_name}_graph', $$
-                MATCH (n:Memory)
-                WHERE {ids_where}
-                DETACH DELETE n
-            $$) AS (result agtype)
+            DELETE FROM "{self.db_name}_graph"."Memory"
+            WHERE {where_clause}
         """

         logger.info(
@@ -4899,20 +4894,20 @@ def delete_node_by_prams(
             with conn.cursor() as cursor:
                 # Count nodes before deletion
                 cursor.execute(count_query)
-                count_results = cursor.fetchall()
-                expected_count = 0
-                if count_results and len(count_results) > 0:
-                    count_str = str(count_results[0][0])
-                    count_str = count_str.strip('"').strip("'")
-                    expected_count = int(count_str) if count_str.isdigit() else 0
+                count_result = cursor.fetchone()
+                expected_count = count_result[0] if count_result else 0
+
+                logger.info(
+                    f"[delete_node_by_prams] Found {expected_count} nodes matching the criteria"
+                )

                 # Delete nodes
                 cursor.execute(delete_query)
-                # Use the count from before deletion as the actual deleted count
-                deleted_count = expected_count
+                # Use rowcount to get actual deleted count
+                deleted_count = cursor.rowcount
                 elapsed_time = time.time() - batch_start_time
                 logger.info(
-                    f"[delete_node_by_prams] execute_values completed successfully in {elapsed_time:.2f}s"
+                    f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, deleted {deleted_count} nodes"
                )
         except Exception as e:
             logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)

src/memos/mem_reader/multi_modal_struct.py

Lines changed: 78 additions & 25 deletions
@@ -8,8 +8,9 @@
 from memos.configs.mem_reader import MultiModalStructMemReaderConfig
 from memos.context.context import ContextThreadPoolExecutor
 from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang
+from memos.mem_reader.read_multi_modal.base import _derive_key
 from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
-from memos.memories.textual.item import TextualMemoryItem
+from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
 from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
 from memos.types import MessagesType
 from memos.utils import timed
@@ -184,6 +185,33 @@ def _concat_multi_modal_memories(
             if window:
                 windows.append(window)

+        # Batch compute embeddings for all windows
+        if windows:
+            # Collect all valid windows that need embedding
+            valid_windows = [w for w in windows if w and w.memory]
+
+            if valid_windows:
+                # Collect all texts that need embedding
+                texts_to_embed = [w.memory for w in valid_windows]
+
+                # Batch compute all embeddings at once
+                try:
+                    embeddings = self.embedder.embed(texts_to_embed)
+                    # Fill embeddings back into memory items
+                    for window, embedding in zip(valid_windows, embeddings, strict=True):
+                        window.metadata.embedding = embedding
+                except Exception as e:
+                    logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
+                    # Fallback: compute embeddings individually
+                    for window in valid_windows:
+                        if window.memory:
+                            try:
+                                window.metadata.embedding = self.embedder.embed([window.memory])[0]
+                            except Exception as e2:
+                                logger.error(
+                                    f"[MultiModalStruct] Error computing embedding for item: {e2}"
+                                )
+
         return windows

     def _build_window_from_items(
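
The new block replaces one embed() call per window with a single batched call over all window texts, falling back to per-item calls only if the batch fails. The same pattern in isolation, as a rough sketch (the embedder is assumed to accept a list of strings and return one vector per string; the helper name and items are illustrative):

def embed_in_batch(embedder, items, logger):
    """Fill item.metadata.embedding for every item that has text, using one batched call."""
    valid = [item for item in items if item and item.memory]
    if not valid:
        return
    try:
        vectors = embedder.embed([item.memory for item in valid])  # one round trip
        for item, vector in zip(valid, vectors, strict=True):
            item.metadata.embedding = vector
    except Exception as exc:
        logger.error(f"batch embedding failed, retrying per item: {exc}")
        for item in valid:
            try:
                item.metadata.embedding = embedder.embed([item.memory])[0]
            except Exception as exc2:
                logger.error(f"embedding failed for one item: {exc2}")

zip(..., strict=True) makes a length mismatch between inputs and returned vectors fail loudly instead of silently misaligning embeddings.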
@@ -247,17 +275,35 @@ def _build_window_from_items(
             # If no text content, return None
             return None

-        # Create aggregated memory item (similar to _build_fast_node in simple_struct)
+        # Create aggregated memory item without embedding (will be computed in batch later)
         extra_kwargs: dict[str, Any] = {}
         if aggregated_file_ids:
             extra_kwargs["file_ids"] = aggregated_file_ids
-        aggregated_item = self._make_memory_item(
-            value=merged_text,
-            info=info,
-            memory_type=memory_type,
-            tags=["mode:fast"],
-            sources=all_sources,
-            **extra_kwargs,
+
+        # Extract info fields
+        info_ = info.copy()
+        user_id = info_.pop("user_id", "")
+        session_id = info_.pop("session_id", "")
+
+        # Create memory item without embedding (set to None, will be filled in batch)
+        aggregated_item = TextualMemoryItem(
+            memory=merged_text,
+            metadata=TreeNodeTextualMemoryMetadata(
+                user_id=user_id,
+                session_id=session_id,
+                memory_type=memory_type,
+                status="activated",
+                tags=["mode:fast"],
+                key=_derive_key(merged_text),
+                embedding=None,  # Will be computed in batch
+                usage=[],
+                sources=all_sources,
+                background="",
+                confidence=0.99,
+                type="fact",
+                info=info_,
+                **extra_kwargs,
+            ),
         )

         return aggregated_item
@@ -282,22 +328,23 @@ def _get_llm_response(
         Returns:
             LLM response dictionary
         """
-        # Try to extract actual text content from sources for better language detection
-        text_for_lang_detection = mem_str
+        # Determine language: prioritize lang from sources (set in fast mode),
+        # fallback to detecting from mem_str if sources don't have lang
+        lang = None
+
+        # First, try to get lang from sources (fast mode already set this)
         if sources:
-            source_texts = []
             for source in sources:
-                if hasattr(source, "content") and source.content:
-                    source_texts.append(source.content)
-                elif isinstance(source, dict) and source.get("content"):
-                    source_texts.append(source.get("content"))
-
-            # If we have text content from sources, use it for language detection
-            if source_texts:
-                text_for_lang_detection = " ".join(source_texts)
-
-        # Use the extracted text for language detection
-        lang = detect_lang(text_for_lang_detection)
+                if hasattr(source, "lang") and source.lang:
+                    lang = source.lang
+                    break
+                elif isinstance(source, dict) and source.get("lang"):
+                    lang = source.get("lang")
+                    break
+
+        # Fallback: detect language from mem_str if no lang from sources
+        if lang is None:
+            lang = detect_lang(mem_str)

         # Select prompt template based on prompt_type
         if prompt_type == "doc":
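
Language selection now prefers a lang value already attached to a source (set during the fast pass) and only falls back to detect_lang on the merged text. The same priority order as a small standalone sketch (the helper name and the object-or-dict handling are illustrative, not an API from the repo):

def resolve_lang(sources, mem_str, detect_lang):
    """Return the first explicit lang found on a source, else detect from the merged text."""
    for source in sources or []:
        lang = getattr(source, "lang", None) or (
            source.get("lang") if isinstance(source, dict) else None
        )
        if lang:
            return lang
    return detect_lang(mem_str)

Detecting from mem_str only as a fallback avoids re-running language detection on text whose language was already determined when the sources were parsed.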
@@ -574,8 +621,13 @@ def _process_multi_modal_data(
         for fast_item in fast_memory_items:
             sources = fast_item.metadata.sources
             for source in sources:
+                lang = getattr(source, "lang", "en")
                 items = self.multi_modal_parser.process_transfer(
-                    source, context_items=[fast_item], custom_tags=custom_tags, info=info
+                    source,
+                    context_items=[fast_item],
+                    custom_tags=custom_tags,
+                    info=info,
+                    lang=lang,
                 )
                 fine_memory_items.extend(items)
         return fine_memory_items
@@ -616,8 +668,9 @@ def _process_transfer_multi_modal_data(

         # Part B: get fine multimodal items
         for source in sources:
+            lang = getattr(source, "lang", "en")
             items = self.multi_modal_parser.process_transfer(
-                source, context_items=[raw_node], info=info, custom_tags=custom_tags
+                source, context_items=[raw_node], info=info, custom_tags=custom_tags, lang=lang
             )
             fine_memory_items.extend(items)
         return fine_memory_items
