Skip to content

Commit 57e42f1

Browse files
authored
feat: add full text memory by jieba (#532)
* feat: add full text memory * fix: remove search
1 parent 5d434ea commit 57e42f1

File tree

9 files changed

+240
-15
lines changed

9 files changed

+240
-15
lines changed

src/memos/api/handlers/component_init.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
4141
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
4242
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
43+
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
4344

4445

4546
if TYPE_CHECKING:
@@ -142,7 +143,7 @@ def init_server() -> dict[str, Any]:
142143
)
143144

144145
logger.debug("Memory manager initialized")
145-
146+
tokenizer = FastTokenizer()
146147
# Initialize text memory
147148
text_mem = SimpleTreeTextMemory(
148149
llm=llm,
@@ -153,6 +154,7 @@ def init_server() -> dict[str, Any]:
153154
memory_manager=memory_manager,
154155
config=default_cube_config.text_mem.config,
155156
internet_retriever=internet_retriever,
157+
tokenizer=tokenizer,
156158
)
157159

158160
logger.debug("Text memory initialized")
@@ -270,7 +272,6 @@ def init_server() -> dict[str, Any]:
270272

271273
online_bot = get_online_bot_function() if dingding_enabled else None
272274
logger.info("DingDing bot is enabled")
273-
274275
# Return all components as a dictionary for easy access and extension
275276
return {
276277
"graph_db": graph_db,

src/memos/api/handlers/search_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def _fast_search(
191191
"""
192192
target_session_id = search_req.session_id or "default_session"
193193
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
194-
194+
plugin = bool(search_req.info is not None and search_req.info.get("origin_model"))
195195
search_results = self.naive_mem_cube.text_mem.search(
196196
query=search_req.query,
197197
user_name=user_context.mem_cube_id,
@@ -205,6 +205,7 @@ def _fast_search(
205205
"session_id": target_session_id,
206206
"chat_history": search_req.chat_history,
207207
},
208+
plugin=plugin,
208209
)
209210

210211
formatted_memories = [format_memory_item(data) for data in search_results]

src/memos/api/product_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class APISearchRequest(BaseRequest):
185185
)
186186
include_preference: bool = Field(True, description="Whether to handle preference memory")
187187
pref_top_k: int = Field(6, description="Number of preference results to return")
188+
info: dict | None = Field(None, description="Info for search")
188189

189190

190191
class APIADDRequest(BaseRequest):

src/memos/graph_dbs/polardb.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,115 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
14501450
"""Get the ordered context chain starting from a node."""
14511451
raise NotImplementedError
14521452

1453+
@timed
1454+
def search_by_fulltext(
1455+
self,
1456+
query_words: list[str],
1457+
top_k: int = 10,
1458+
scope: str | None = None,
1459+
status: str | None = None,
1460+
threshold: float | None = None,
1461+
search_filter: dict | None = None,
1462+
user_name: str | None = None,
1463+
tsvector_field: str = "properties_tsvector_zh",
1464+
tsquery_config: str = "jiebaqry",
1465+
**kwargs,
1466+
) -> list[dict]:
1467+
"""
1468+
Full-text search functionality using PostgreSQL's full-text search capabilities.
1469+
1470+
Args:
1471+
query_text: query text
1472+
top_k: maximum number of results to return
1473+
scope: memory type filter (memory_type)
1474+
status: status filter, defaults to "activated"
1475+
threshold: similarity threshold filter
1476+
search_filter: additional property filter conditions
1477+
user_name: username filter
1478+
tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1
1479+
tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation)
1480+
**kwargs: other parameters (e.g. cube_name)
1481+
1482+
Returns:
1483+
list[dict]: result list containing id and score
1484+
"""
1485+
# Build WHERE clause dynamically, same as search_by_embedding
1486+
where_clauses = []
1487+
1488+
if scope:
1489+
where_clauses.append(
1490+
f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype"
1491+
)
1492+
if status:
1493+
where_clauses.append(
1494+
f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype"
1495+
)
1496+
else:
1497+
where_clauses.append(
1498+
"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype"
1499+
)
1500+
1501+
# Add user_name filter
1502+
user_name = user_name if user_name else self.config.user_name
1503+
where_clauses.append(
1504+
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype"
1505+
)
1506+
1507+
# Add search_filter conditions
1508+
if search_filter:
1509+
for key, value in search_filter.items():
1510+
if isinstance(value, str):
1511+
where_clauses.append(
1512+
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype"
1513+
)
1514+
else:
1515+
where_clauses.append(
1516+
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype"
1517+
)
1518+
1519+
# Add fulltext search condition
1520+
# Convert query_text to OR query format: "word1 | word2 | word3"
1521+
tsquery_string = " | ".join(query_words)
1522+
1523+
where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")
1524+
1525+
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
1526+
1527+
# Build fulltext search query
1528+
query = f"""
1529+
SELECT
1530+
ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id,
1531+
agtype_object_field_text(properties, 'memory') as memory_text,
1532+
ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank
1533+
FROM "{self.db_name}_graph"."Memory"
1534+
{where_clause}
1535+
ORDER BY rank DESC
1536+
LIMIT {top_k};
1537+
"""
1538+
1539+
params = [tsquery_string, tsquery_string]
1540+
1541+
conn = self._get_connection()
1542+
try:
1543+
with conn.cursor() as cursor:
1544+
cursor.execute(query, params)
1545+
results = cursor.fetchall()
1546+
output = []
1547+
for row in results:
1548+
oldid = row[0] # old_id
1549+
rank = row[2] # rank score
1550+
1551+
id_val = str(oldid)
1552+
score_val = float(rank)
1553+
1554+
# Apply threshold filter if specified
1555+
if threshold is None or score_val >= threshold:
1556+
output.append({"id": id_val, "score": score_val})
1557+
1558+
return output[:top_k]
1559+
finally:
1560+
self._return_connection(conn)
1561+
14531562
@timed
14541563
def search_by_embedding(
14551564
self,

src/memos/memories/textual/simple_tree.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from memos.memories.textual.tree import TreeTextMemory
1010
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
1111
from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25
12+
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
1213
from memos.reranker.base import BaseReranker
1314

1415

@@ -35,6 +36,7 @@ def __init__(
3536
config: TreeTextMemoryConfig,
3637
internet_retriever: None = None,
3738
is_reorganize: bool = False,
39+
tokenizer: FastTokenizer | None = None,
3840
):
3941
"""Initialize memory with the given configuration."""
4042
self.config: TreeTextMemoryConfig = config
@@ -51,6 +53,7 @@ def __init__(
5153
if self.search_strategy and self.search_strategy.get("bm25", False)
5254
else None
5355
)
56+
self.tokenizer = tokenizer
5457
self.reranker = reranker
5558
self.memory_manager: MemoryManager = memory_manager
5659
# Create internet retriever if configured

src/memos/memories/textual/tree.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(self, config: TreeTextMemoryConfig):
8989
)
9090
else:
9191
logger.info("No internet retriever configured")
92+
self.tokenizer = None
9293

9394
def add(
9495
self,
@@ -165,6 +166,7 @@ def search(
165166
moscube: bool = False,
166167
search_filter: dict | None = None,
167168
user_name: str | None = None,
169+
**kwargs,
168170
) -> list[TextualMemoryItem]:
169171
"""Search for memories based on a query.
170172
User query -> TaskGoalParser -> MemoryPathResolver ->
@@ -199,6 +201,7 @@ def search(
199201
moscube=moscube,
200202
search_strategy=self.search_strategy,
201203
manual_close_internet=manual_close_internet,
204+
tokenizer=self.tokenizer,
202205
)
203206
else:
204207
searcher = Searcher(
@@ -211,9 +214,17 @@ def search(
211214
moscube=moscube,
212215
search_strategy=self.search_strategy,
213216
manual_close_internet=manual_close_internet,
217+
tokenizer=self.tokenizer,
214218
)
215219
return searcher.search(
216-
query, top_k, info, mode, memory_type, search_filter, user_name=user_name
220+
query,
221+
top_k,
222+
info,
223+
mode,
224+
memory_type,
225+
search_filter,
226+
user_name=user_name,
227+
plugin=kwargs.get("plugin", False),
217228
)
218229

219230
def get_relevant_subgraph(

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,26 @@ def retrieve_from_cube(
143143

144144
return list(combined.values())
145145

146+
def retrieve_from_mixed(
147+
self,
148+
top_k: int,
149+
memory_scope: str | None = None,
150+
query_embedding: list[list[float]] | None = None,
151+
search_filter: dict | None = None,
152+
user_name: str | None = None,
153+
use_fast_graph: bool = False,
154+
) -> list[TextualMemoryItem]:
155+
"""Retrieve from mixed and memory"""
156+
vector_results = self._vector_recall(
157+
query_embedding or [],
158+
memory_scope,
159+
top_k,
160+
search_filter=search_filter,
161+
user_name=user_name,
162+
) # Merge and deduplicate by ID
163+
combined = {item.id: item for item in vector_results}
164+
return list(combined.values())
165+
146166
def _graph_recall(
147167
self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs
148168
) -> list[TextualMemoryItem]:
@@ -270,7 +290,7 @@ def _vector_recall(
270290
query_embedding: list[list[float]],
271291
memory_scope: str,
272292
top_k: int = 20,
273-
max_num: int = 5,
293+
max_num: int = 20,
274294
status: str = "activated",
275295
cube_name: str | None = None,
276296
search_filter: dict | None = None,

src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from pathlib import Path
55

6+
import numpy as np
7+
68
from memos.dependency import require_python_package
79
from memos.log import get_logger
810

@@ -376,3 +378,28 @@ def detect_lang(text):
376378
return "en"
377379
except Exception:
378380
return "en"
381+
382+
383+
def find_best_unrelated_subgroup(sentences: list, similarity_matrix: list, bar: float = 0.8):
384+
assert len(sentences) == len(similarity_matrix)
385+
386+
num_sentence = len(sentences)
387+
selected_sentences = []
388+
selected_indices = []
389+
for i in range(num_sentence):
390+
can_add = True
391+
for j in selected_indices:
392+
if similarity_matrix[i][j] > bar:
393+
can_add = False
394+
break
395+
if can_add:
396+
selected_sentences.append(i)
397+
selected_indices.append(i)
398+
return selected_sentences, selected_indices
399+
400+
401+
def cosine_similarity_matrix(embeddings: list[list[float]]) -> list[list[float]]:
402+
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
403+
x_normalized = embeddings / norms
404+
similarity_matrix = np.dot(x_normalized, x_normalized.T)
405+
return similarity_matrix

0 commit comments

Comments
 (0)