
Commit f37b15b

Author: 黑布林
Merge branch 'dev' into dev_test
2 parents 73106ed + b7ffa5a

File tree

5 files changed: +153 -57 lines


src/memos/graph_dbs/polardb.py

Lines changed: 106 additions & 14 deletions

```diff
@@ -4763,7 +4763,7 @@ def process_condition(condition):
     @timed
     def delete_node_by_prams(
         self,
-        writable_cube_ids: list[str],
+        writable_cube_ids: list[str] | None = None,
         memory_ids: list[str] | None = None,
         file_ids: list[str] | None = None,
         filter: dict | None = None,
@@ -4772,7 +4772,8 @@ def delete_node_by_prams(
         Delete nodes by memory_ids, file_ids, or filter.
 
         Args:
-            writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
+            writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes.
+                If not provided, no user_name filter will be applied.
             memory_ids (list[str], optional): List of memory node IDs to delete.
             file_ids (list[str], optional): List of file node IDs to delete.
             filter (dict, optional): Filter dictionary to query matching nodes for deletion.
@@ -4785,17 +4786,15 @@ def delete_node_by_prams(
             f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
         )
 
-        # Validate writable_cube_ids
-        if not writable_cube_ids or len(writable_cube_ids) == 0:
-            raise ValueError("writable_cube_ids is required and cannot be empty")
-
         # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
+        # Only add user_name filter if writable_cube_ids is provided
         user_name_conditions = []
-        for cube_id in writable_cube_ids:
-            # Use agtype_access_operator with VARIADIC ARRAY format for consistency
-            user_name_conditions.append(
-                f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
-            )
+        if writable_cube_ids and len(writable_cube_ids) > 0:
+            for cube_id in writable_cube_ids:
+                # Use agtype_access_operator with VARIADIC ARRAY format for consistency
+                user_name_conditions.append(
+                    f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
+                )
 
         # Build WHERE conditions separately for memory_ids and file_ids
         where_conditions = []
@@ -4863,9 +4862,14 @@ def delete_node_by_prams(
         # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
         data_conditions = " OR ".join([f"({cond})" for cond in where_conditions])
 
-        # Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
-        user_name_where = " OR ".join(user_name_conditions)
-        where_clause = f"({user_name_where}) AND ({data_conditions})"
+        # Build final WHERE clause
+        # If user_name_conditions exist, combine with data_conditions using AND
+        # Otherwise, use only data_conditions
+        if user_name_conditions:
+            user_name_where = " OR ".join(user_name_conditions)
+            where_clause = f"({user_name_where}) AND ({data_conditions})"
+        else:
+            where_clause = f"({data_conditions})"
 
         # Use SQL DELETE query for better performance
         # First count matching nodes to get accurate count
@@ -4917,3 +4921,91 @@ def delete_node_by_prams(
 
         logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
         return deleted_count
+
+    @timed
+    def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
+        """Get user names by memory ids.
+
+        Args:
+            memory_ids: List of memory node IDs to query.
+
+        Returns:
+            dict[str, list[str]]: Dictionary with one key:
+                - 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing)
+                - 'exist_user_names': List of distinct user names (if all memory_ids exist)
+        """
+        if not memory_ids:
+            return {"exist_user_names": []}
+
+        # Build OR conditions for each memory_id
+        id_conditions = []
+        for mid in memory_ids:
+            id_conditions.append(
+                f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{mid}\"'::agtype"
+            )
+
+        where_clause = f"({' OR '.join(id_conditions)})"
+
+        # Query to check which memory_ids exist
+        check_query = f"""
+            SELECT ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text
+            FROM "{self.db_name}_graph"."Memory"
+            WHERE {where_clause}
+        """
+
+        logger.info(f"[get_user_names_by_memory_ids] check_query: {check_query}")
+        conn = None
+        try:
+            conn = self._get_connection()
+            with conn.cursor() as cursor:
+                # Check which memory_ids exist
+                cursor.execute(check_query)
+                check_results = cursor.fetchall()
+                existing_ids = set()
+                for row in check_results:
+                    node_id = row[0]
+                    # Remove quotes if present
+                    if isinstance(node_id, str):
+                        node_id = node_id.strip('"').strip("'")
+                    existing_ids.add(node_id)
+
+                # Check if any memory_ids are missing
+                no_exist_list = [mid for mid in memory_ids if mid not in existing_ids]
+
+                # If any memory_ids are missing, return no_exist_memory_ids
+                if no_exist_list:
+                    logger.info(
+                        f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}"
+                    )
+                    return {"no_exist_memory_ids": no_exist_list}
+
+                # All memory_ids exist, query user_names
+                user_names_query = f"""
+                    SELECT DISTINCT ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text
+                    FROM "{self.db_name}_graph"."Memory"
+                    WHERE {where_clause}
+                """
+                logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}")
+
+                cursor.execute(user_names_query)
+                results = cursor.fetchall()
+                user_names = []
+                for row in results:
+                    user_name = row[0]
+                    # Remove quotes if present
+                    if isinstance(user_name, str):
+                        user_name = user_name.strip('"').strip("'")
+                    user_names.append(user_name)
+
+                logger.info(
+                    f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names"
+                )
+
+                return {"exist_user_names": user_names}
+        except Exception as e:
+            logger.error(
+                f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
+            )
+            raise
+        finally:
+            self._return_connection(conn)
```
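For orientation, a hedged usage sketch of the two entry points this file's diff touches; the `db` handle and all IDs below are hypothetical, and only the call shapes and the return contract come from the code above.

```python
# Hypothetical usage sketch: `db` stands in for a PolarDB graph-store
# instance exposing the methods changed/added above; IDs are made up.

# writable_cube_ids is now optional: omitting it skips the user_name
# filter entirely (previously this raised ValueError).
deleted = db.delete_node_by_prams(memory_ids=["mem-1", "mem-2"])

# Scoped variant, unchanged behavior: delete only nodes whose user_name
# matches one of the cube IDs AND one of the data conditions.
deleted = db.delete_node_by_prams(
    writable_cube_ids=["cube_a", "cube_b"],
    memory_ids=["mem-1"],
)

# The new helper returns exactly one of two keys, so callers branch on it.
result = db.get_user_names_by_memory_ids(["mem-1", "mem-2"])
if "no_exist_memory_ids" in result:
    print("missing:", result["no_exist_memory_ids"])  # some IDs not found
else:
    print("owners:", result["exist_user_names"])  # distinct user_names
```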

src/memos/mem_reader/read_multi_modal/utils.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -4,7 +4,7 @@
 import os
 import re
 
-from datetime import datetime, timezone
+from datetime import datetime
 from typing import Any, TypeAlias
 from urllib.parse import urlparse
 
@@ -245,8 +245,8 @@ def coerce_scene_data(scene_data: SceneDataInput, scene_type: str) -> list[Messa
 
     # Default timestamp
     if chat_time_value is None:
-        session_date = datetime.now(timezone.utc)
-        date_format = "%I:%M %p on %d %B, %Y UTC"
+        session_date = datetime.now()
+        date_format = "%I:%M %p on %d %B, %Y"
         chat_time_value = session_date.strftime(date_format)
 
     # Inject chat_time
```
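To make the behavioral change concrete, a standalone sketch of the old vs. new default timestamp (the sample outputs are illustrative):

```python
from datetime import datetime, timezone

# Before: default timestamp pinned to UTC and labelled as such.
old = datetime.now(timezone.utc).strftime("%I:%M %p on %d %B, %Y UTC")
# e.g. "09:30 AM on 05 March, 2025 UTC"

# After: server-local wall-clock time, no timezone suffix.
new = datetime.now().strftime("%I:%M %p on %d %B, %Y")
# e.g. "05:30 PM on 05 March, 2025"

print(old, "->", new)
```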

src/memos/mem_scheduler/task_schedule_modules/redis_queue.py

Lines changed: 17 additions & 21 deletions

```diff
@@ -699,27 +699,23 @@ def _batch_claim_pending_messages(
         results = []
         try:
             results = pipe.execute()
-        except Exception as e:
-            err_msg = str(e).lower()
-            if "nogroup" in err_msg or "no such key" in err_msg:
-                # Fallback: attempt sequential xautoclaim for robustness
-                for stream_key, need_count, label in claims_spec:
-                    try:
-                        self._ensure_consumer_group(stream_key=stream_key)
-                        res = self._redis_conn.xautoclaim(
-                            name=stream_key,
-                            groupname=self.consumer_group,
-                            consumername=self.consumer_name,
-                            min_idle_time=self.orchestrator.get_task_idle_min(task_label=label),
-                            start_id="0-0",
-                            count=need_count,
-                            justid=False,
-                        )
-                        results.append(res)
-                    except Exception:
-                        continue
-            else:
-                logger.error(f"Pipeline xautoclaim failed: {e}")
+        except Exception:
+            # Fallback: attempt sequential xautoclaim for robustness
+            for stream_key, need_count, label in claims_spec:
+                try:
+                    self._ensure_consumer_group(stream_key=stream_key)
+                    res = self._redis_conn.xautoclaim(
+                        name=stream_key,
+                        groupname=self.consumer_group,
+                        consumername=self.consumer_name,
+                        min_idle_time=self.orchestrator.get_task_idle_min(task_label=label),
+                        start_id="0-0",
+                        count=need_count,
+                        justid=False,
+                    )
+                    results.append(res)
+                except Exception:
+                    continue
 
         claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = []
         for (stream_key, _need_count, _label), claimed_result in zip(
```
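A minimal standalone sketch of the same pattern, assuming redis-py 4.x; the stream, group, and consumer names are hypothetical. The point of the change: the fallback now runs on any pipeline failure, not only NOGROUP/"no such key" errors, and the old `else` logging branch is gone.

```python
import redis

r = redis.Redis()  # hypothetical connection
streams = ["stream:a", "stream:b"]

try:
    # Fast path: one pipelined round trip claiming idle pending entries
    # from every stream at once.
    pipe = r.pipeline()
    for s in streams:
        pipe.xautoclaim(s, "workers", "worker-1", min_idle_time=60_000, count=10)
    results = pipe.execute()
except Exception:
    # Fallback: claim stream by stream, so a failure on one stream
    # (e.g. a missing consumer group) no longer drops the whole batch.
    results = []
    for s in streams:
        try:
            results.append(
                r.xautoclaim(s, "workers", "worker-1", min_idle_time=60_000, count=10)
            )
        except Exception:
            continue
```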

src/memos/multi_mem_cube/composite_cube.py

Lines changed: 24 additions & 14 deletions

```diff
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any
 
@@ -46,21 +47,30 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
             "tool_mem": [],
         }
 
-        for view in self.cube_views:
+        def _search_single_cube(view: SingleCubeView) -> dict[str, Any]:
             self.logger.info(f"[CompositeCubeView] fan-out search to cube={view.cube_id}")
-            cube_result = view.search_memories(search_req)
-            merged_results["text_mem"].extend(cube_result.get("text_mem", []))
-            merged_results["act_mem"].extend(cube_result.get("act_mem", []))
-            merged_results["para_mem"].extend(cube_result.get("para_mem", []))
-            merged_results["pref_mem"].extend(cube_result.get("pref_mem", []))
-            merged_results["tool_mem"].extend(cube_result.get("tool_mem", []))
-
-            note = cube_result.get("pref_note")
-            if note:
-                if merged_results["pref_note"]:
-                    merged_results["pref_note"] += " | " + note
-                else:
-                    merged_results["pref_note"] = note
+            return view.search_memories(search_req)
+
+        # parallel search for each cube
+        with ThreadPoolExecutor(max_workers=2) as executor:
+            future_to_view = {
+                executor.submit(_search_single_cube, view): view for view in self.cube_views
+            }
+
+            for future in as_completed(future_to_view):
+                cube_result = future.result()
+                merged_results["text_mem"].extend(cube_result.get("text_mem", []))
+                merged_results["act_mem"].extend(cube_result.get("act_mem", []))
+                merged_results["para_mem"].extend(cube_result.get("para_mem", []))
+                merged_results["pref_mem"].extend(cube_result.get("pref_mem", []))
+                merged_results["tool_mem"].extend(cube_result.get("tool_mem", []))
+
+                note = cube_result.get("pref_note")
+                if note:
+                    if merged_results["pref_note"]:
+                        merged_results["pref_note"] += " | " + note
+                    else:
+                        merged_results["pref_note"] = note
 
         return merged_results
 
```
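A self-contained sketch of the fan-out/merge pattern this diff introduces; the worker function and cube IDs are hypothetical stand-ins for `SingleCubeView.search_memories`.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def search_one(cube_id: str) -> dict[str, list[str]]:
    # Stand-in for view.search_memories(search_req).
    return {"text_mem": [f"{cube_id}:hit"]}


merged: dict[str, list[str]] = {"text_mem": []}
with ThreadPoolExecutor(max_workers=2) as executor:
    futures = {executor.submit(search_one, cid): cid for cid in ["cube_a", "cube_b"]}
    # as_completed() yields futures in completion order, so the merged
    # list order is nondeterministic: a trade-off the old sequential
    # loop did not have.
    for future in as_completed(futures):
        merged["text_mem"].extend(future.result().get("text_mem", []))

print(merged["text_mem"])
```

Note that `future.result()` re-raises any exception raised inside a cube search, so one failing cube still aborts the whole merge, just as in the sequential version.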

src/memos/utils.py

Lines changed: 3 additions & 5 deletions

```diff
@@ -6,9 +6,6 @@
 
 logger = get_logger(__name__)
 
-# Global threshold (seconds) for timing logs
-DEFAULT_TIME_BAR = 10.0
-
 
 def timed_with_status(
     func=None,
@@ -97,7 +94,7 @@ def wrapper(*args, **kwargs):
     return decorator(func)
 
 
-def timed(func=None, *, log=False, log_prefix=""):
+def timed(func=None, *, log=True, log_prefix=""):
     def decorator(fn):
         def wrapper(*args, **kwargs):
             start = time.perf_counter()
@@ -107,7 +104,8 @@ def wrapper(*args, **kwargs):
             if log is not True:
                 return result
 
-            if elapsed_ms >= (DEFAULT_TIME_BAR * 1000.0):
+            # 100ms threshold
+            if elapsed_ms >= 100.0:
                 logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms")
 
             return result
```
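To see what the new defaults mean in practice, a small sketch; it assumes `timed` is importable from `memos.utils` as in this diff, and `slow_op` is a hypothetical function.

```python
import time

from memos.utils import timed


@timed  # `log` now defaults to True
def slow_op() -> None:
    time.sleep(0.15)


slow_op()
# Expected: an INFO log line like "[TIMER] slow_op took 150 ms", since
# 150 ms >= the new 100 ms threshold. Under the old defaults this logged
# nothing: `log` defaulted to False and the bar was 10 s (DEFAULT_TIME_BAR).
```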
