Skip to content

Commit 71f8edf

Browse files
authored
Feat: sync hotfix to dev and add full text for polardb (#563)
* feat: update memos headers * feat: headers add * feat: update search agent * feat: update mem story * feat: update mem scheduler * feat: update deepsearch mem code * feat: update deepsearch agent * feat: update test code * fix: remove dup config * feat: dock search pipeline * fix: code test * feat: add test scripts * feat: add test * feat: update need_raw process * fix: add initter * fix: change agent search func name * feat: update logs and defined * feat: update full text mem search * feat: cp plugin to dev
1 parent e08e164 commit 71f8edf

File tree

10 files changed

+264
-13
lines changed

10 files changed

+264
-13
lines changed

src/memos/api/handlers/component_init.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
4242
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
4343
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
44+
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
4445

4546

4647
if TYPE_CHECKING:
@@ -196,6 +197,7 @@ def init_server() -> dict[str, Any]:
196197

197198
logger.debug("Memory manager initialized")
198199

200+
tokenizer = FastTokenizer()
199201
# Initialize text memory
200202
text_mem = SimpleTreeTextMemory(
201203
llm=llm,
@@ -206,6 +208,7 @@ def init_server() -> dict[str, Any]:
206208
memory_manager=memory_manager,
207209
config=default_cube_config.text_mem.config,
208210
internet_retriever=internet_retriever,
211+
tokenizer=tokenizer,
209212
)
210213

211214
logger.debug("Text memory initialized")

src/memos/api/product_models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,12 @@ class APISearchRequest(BaseRequest):
388388
description="(Internal) Operation definitions for multi-cube read permissions.",
389389
)
390390

391+
# ==== Source for plugin ====
392+
source: str | None = Field(
393+
None,
394+
description="Source of the search query [plugin will router diff search]",
395+
)
396+
391397
@model_validator(mode="after")
392398
def _convert_deprecated_fields(self) -> "APISearchRequest":
393399
"""

src/memos/graph_dbs/polardb.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,6 +1451,130 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
14511451
"""Get the ordered context chain starting from a node."""
14521452
raise NotImplementedError
14531453

1454+
@timed
def search_by_fulltext(
    self,
    query_words: list[str],
    top_k: int = 10,
    scope: str | None = None,
    status: str | None = None,
    threshold: float | None = None,
    search_filter: dict | None = None,
    user_name: str | None = None,
    filter: dict | None = None,
    knowledgebase_ids: list[str] | None = None,
    tsvector_field: str = "properties_tsvector_zh",
    tsquery_config: str = "jiebaqry",
    **kwargs,
) -> list[dict]:
    """
    Full-text search over Memory nodes using PostgreSQL text-search functions.

    The query words are OR-ed into a single tsquery ("w1 | w2 | w3"), matched
    against ``tsvector_field`` and ranked with ``ts_rank``.

    Args:
        query_words: words to search for; blank entries are ignored. An empty
            list yields an empty result without touching the database.
        top_k: maximum number of results to return; values <= 0 yield [].
        scope: memory type filter (matched against the "memory_type" property)
        status: status filter, defaults to "activated"
        threshold: minimum ts_rank score a row must reach to be returned
        search_filter: additional property equality conditions (key -> value)
        user_name: username filter
        filter: filter conditions with 'and' or 'or' logic for search results.
        knowledgebase_ids: knowledgebase ids filter (OR-ed with the user filter)
        tsvector_field: full-text index column name, defaults to
            properties_tsvector_zh
        tsquery_config: text-search configuration, defaults to jiebaqry
        **kwargs: accepted for call compatibility; not read by this method

    Returns:
        list[dict]: ``{"id": str, "score": float}`` entries, best rank first

    Raises:
        ValueError: if ``tsvector_field`` is not a plain ASCII identifier.
    """
    import json  # local import: only used to encode agtype literals below

    # A column name cannot be bound as a query parameter, so refuse anything
    # that is not a simple identifier instead of splicing it into the SQL.
    if not (tsvector_field.isascii() and tsvector_field.isidentifier()):
        raise ValueError(f"Invalid tsvector_field: {tsvector_field!r}")

    words = [word.strip() for word in (query_words or []) if word and word.strip()]
    limit = int(top_k)
    if not words or limit <= 0:
        # An empty tsquery or a non-positive LIMIT cannot produce useful rows.
        return []

    # Values are JSON-encoded and bound as parameters rather than formatted
    # into the statement, so quotes or '%' in user input cannot alter the SQL.
    # NOTE(review): relies on the driver rendering each %s as a quoted literal
    # before the ::agtype cast — confirm for the driver in use.
    property_equals = "ag_catalog.agtype_access_operator(properties, %s::agtype) = %s::agtype"
    where_clauses = []
    where_params = []

    def add_property_filter(key, value):
        where_clauses.append(property_equals)
        where_params.extend(
            (json.dumps(key, ensure_ascii=False), json.dumps(value, ensure_ascii=False))
        )

    if scope:
        add_property_filter("memory_type", scope)
    add_property_filter("status", status or "activated")

    # user_name / knowledgebase_ids conditions are alternatives (OR relationship)
    user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql(
        user_name=user_name,
        knowledgebase_ids=knowledgebase_ids,
        default_user_name=self.config.user_name,
    )
    if user_name_conditions:
        if len(user_name_conditions) == 1:
            where_clauses.append(user_name_conditions[0])
        else:
            where_clauses.append(f"({' OR '.join(user_name_conditions)})")

    if search_filter:
        for key, value in search_filter.items():
            add_property_filter(key, value)

    # Structured filter conditions come pre-rendered from the shared helper
    where_clauses.extend(self._build_filter_conditions_sql(filter))

    # Full-text condition: any of the words may match
    tsquery_string = " | ".join(words)
    where_clauses.append(f"{tsvector_field} @@ to_tsquery(%s, %s)")
    where_params.extend((tsquery_config, tsquery_string))

    where_sql = " AND ".join(where_clauses)
    query = f"""
        SELECT
            ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id,
            agtype_object_field_text(properties, 'memory') as memory_text,
            ts_rank({tsvector_field}, to_tsquery(%s, %s)) as rank
        FROM "{self.db_name}_graph"."Memory"
        WHERE {where_sql}
        ORDER BY rank DESC
        LIMIT %s;
    """

    # Parameter order follows placeholder order: SELECT, then WHERE, then LIMIT
    params = [tsquery_config, tsquery_string, *where_params, limit]

    conn = self._get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            rows = cursor.fetchall()

        output = []
        for row in rows:
            score_val = float(row[2])  # rank column
            # Apply threshold filter if specified
            if threshold is None or score_val >= threshold:
                output.append({"id": str(row[0]), "score": score_val})
        return output
    finally:
        # Always hand the connection back, even when the query fails
        self._return_connection(conn)
1577+
14541578
@timed
14551579
def search_by_embedding(
14561580
self,

src/memos/memories/textual/simple_tree.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from memos.memories.textual.tree import TreeTextMemory
1010
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
1111
from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25
12+
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
1213
from memos.reranker.base import BaseReranker
1314

1415

@@ -35,6 +36,7 @@ def __init__(
3536
config: TreeTextMemoryConfig,
3637
internet_retriever: None = None,
3738
is_reorganize: bool = False,
39+
tokenizer: FastTokenizer | None = None,
3840
):
3941
"""Initialize memory with the given configuration."""
4042
self.config: TreeTextMemoryConfig = config
@@ -51,6 +53,7 @@ def __init__(
5153
if self.search_strategy and self.search_strategy.get("bm25", False)
5254
else None
5355
)
56+
self.tokenizer = tokenizer
5457
self.reranker = reranker
5558
self.memory_manager: MemoryManager = memory_manager
5659
# Create internet retriever if configured

src/memos/memories/textual/tree.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def __init__(self, config: TreeTextMemoryConfig):
9191
)
9292
else:
9393
logger.info("No internet retriever configured")
94+
self.tokenizer = None
9495

9596
def add(
9697
self,
@@ -165,6 +166,7 @@ def search(
165166
search_priority: dict | None = None,
166167
search_filter: dict | None = None,
167168
user_name: str | None = None,
169+
**kwargs,
168170
) -> list[TextualMemoryItem]:
169171
"""Search for memories based on a query.
170172
User query -> TaskGoalParser -> MemoryPathResolver ->
@@ -197,6 +199,7 @@ def search(
197199
internet_retriever=None,
198200
search_strategy=self.search_strategy,
199201
manual_close_internet=manual_close_internet,
202+
tokenizer=self.tokenizer,
200203
)
201204
else:
202205
searcher = Searcher(
@@ -208,6 +211,7 @@ def search(
208211
internet_retriever=self.internet_retriever,
209212
search_strategy=self.search_strategy,
210213
manual_close_internet=manual_close_internet,
214+
tokenizer=self.tokenizer,
211215
)
212216
return searcher.search(
213217
query,
@@ -218,6 +222,7 @@ def search(
218222
search_filter,
219223
search_priority,
220224
user_name=user_name,
225+
plugin=kwargs.get("plugin", False),
221226
)
222227

223228
def get_relevant_subgraph(

src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
1111
from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25
1212
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
13+
FastTokenizer,
1314
parse_structured_output,
1415
)
1516
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
@@ -33,6 +34,7 @@ def __init__(
3334
search_strategy: dict | None = None,
3435
manual_close_internet: bool = True,
3536
process_llm: Any | None = None,
37+
tokenizer: FastTokenizer | None = None,
3638
):
3739
super().__init__(
3840
dispatcher_llm=dispatcher_llm,
@@ -43,6 +45,7 @@ def __init__(
4345
internet_retriever=internet_retriever,
4446
search_strategy=search_strategy,
4547
manual_close_internet=manual_close_internet,
48+
tokenizer=tokenizer,
4649
)
4750

4851
self.stage_retrieve_top = 3

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,25 @@ def retrieve_from_cube(
148148

149149
return list(combined.values())
150150

151+
def retrieve_from_mixed(
    self,
    top_k: int,
    memory_scope: str | None = None,
    query_embedding: list[list[float]] | None = None,
    search_filter: dict | None = None,
    user_name: str | None = None,
) -> list[TextualMemoryItem]:
    """Return the results of ``self._vector_recall`` de-duplicated by item id.

    Args:
        top_k: forwarded to ``_vector_recall`` as the result limit.
        memory_scope: forwarded to ``_vector_recall`` unchanged.
        query_embedding: query embeddings; ``None`` or empty is passed on as ``[]``.
        search_filter: forwarded to ``_vector_recall`` as a keyword argument.
        user_name: forwarded to ``_vector_recall`` as a keyword argument.

    Returns:
        One item per distinct ``id``. When an id occurs more than once, the
        last occurrence is kept, at the position of the first occurrence.
    """
    embeddings = query_embedding if query_embedding else []
    hits = self._vector_recall(
        embeddings,
        memory_scope,
        top_k,
        search_filter=search_filter,
        user_name=user_name,
    )

    # Collapse duplicates: later hits with the same id overwrite earlier ones
    unique_by_id = {}
    for hit in hits:
        unique_by_id[hit.id] = hit
    return [*unique_by_id.values()]
169+
151170
def _graph_recall(
152171
self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs
153172
) -> list[TextualMemoryItem]:

src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from pathlib import Path
55
from typing import Any
66

7+
import numpy as np
8+
79
from memos.dependency import require_python_package
810
from memos.log import get_logger
911

@@ -463,3 +465,28 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]:
463465
memory["metadata"]["memory"] = memory["memory"]
464466

465467
return memory
468+
469+
470+
def find_best_unrelated_subgroup(sentences: list, similarity_matrix: list, bar: float = 0.8):
    """Greedily pick sentences that are not too similar to any already-picked one.

    Sentences are visited in order; a sentence is kept only if its similarity
    to every previously kept sentence is at most ``bar``.

    Args:
        sentences: candidate sentences.
        similarity_matrix: square matrix where ``similarity_matrix[i][j]`` is
            the similarity between sentence ``i`` and sentence ``j``.
        bar: similarity above which two sentences count as related.

    Returns:
        A ``(selected_sentences, selected_indices)`` tuple: the kept sentences
        and their positions in ``sentences``.

    Raises:
        ValueError: if ``sentences`` and ``similarity_matrix`` differ in length.
    """
    if len(sentences) != len(similarity_matrix):
        raise ValueError(
            "sentences and similarity_matrix must have the same length, "
            f"got {len(sentences)} and {len(similarity_matrix)}"
        )

    selected_sentences = []
    selected_indices = []
    for i, sentence in enumerate(sentences):
        row = similarity_matrix[i]
        if any(row[j] > bar for j in selected_indices):
            continue
        # Keep the sentence itself (not its index) alongside its position
        selected_sentences.append(sentence)
        selected_indices.append(i)
    return selected_sentences, selected_indices
486+
487+
488+
def cosine_similarity_matrix(embeddings: list[list[float]]) -> np.ndarray:
    """Compute the pairwise cosine similarity between embedding vectors.

    Args:
        embeddings: a sequence of equal-length vectors, one per row.

    Returns:
        An ``(n, n)`` numpy array where entry ``[i][j]`` is the cosine
        similarity of rows ``i`` and ``j``. Zero vectors get similarity 0 to
        everything (including themselves) instead of NaN. Empty input yields
        an empty ``(0, 0)`` array.
    """
    vectors = np.asarray(embeddings, dtype=float)
    if vectors.size == 0:
        return np.zeros((0, 0))

    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Avoid 0/0 -> NaN for all-zero rows; dividing by 1 leaves them as zeros
    norms[norms == 0] = 1.0
    x_normalized = vectors / norms
    return np.dot(x_normalized, x_normalized.T)

0 commit comments

Comments
 (0)