5 changes: 3 additions & 2 deletions src/memos/api/handlers/formatters_handler.py
@@ -29,7 +29,7 @@ def to_iter(running: Any) -> list[Any]:
return list(running) if running else []


def format_memory_item(memory_data: Any) -> dict[str, Any]:
def format_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]:
"""
Format a single memory item for API response.

@@ -47,7 +47,8 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]:
ref_id = f"[{memory_id.split('-')[0]}]"

memory["ref_id"] = ref_id
memory["metadata"]["embedding"] = []
if not include_embedding:
memory["metadata"]["embedding"] = []
memory["metadata"]["sources"] = []
memory["metadata"]["usage"] = []
memory["metadata"]["ref_id"] = ref_id
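The new include_embedding flag keeps embedding vectors in the payload only when a caller explicitly asks for them (the similarity-dedup path does); every other caller gets the old strip-by-default behaviour. A minimal sketch of that contract, using a plain dict in place of the real memory model:

```python
from typing import Any


def format_item_sketch(memory: dict[str, Any], include_embedding: bool = False) -> dict[str, Any]:
    """Illustrative sketch of the formatter contract after this change."""
    item = dict(memory)
    item["metadata"] = dict(memory.get("metadata", {}))
    if not include_embedding:
        item["metadata"]["embedding"] = []  # hide vectors from API consumers by default
    return item


stripped = format_item_sketch({"metadata": {"embedding": [0.1, 0.2]}})
kept = format_item_sketch({"metadata": {"embedding": [0.1, 0.2]}}, include_embedding=True)
assert stripped["metadata"]["embedding"] == []
assert kept["metadata"]["embedding"] == [0.1, 0.2]
```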
102 changes: 102 additions & 0 deletions src/memos/api/handlers/search_handler.py
@@ -5,9 +5,14 @@
using dependency injection for better modularity and testability.
"""

from typing import Any

from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
from memos.api.product_models import APISearchRequest, SearchResponse
from memos.log import get_logger
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
cosine_similarity_matrix,
)
from memos.multi_mem_cube.composite_cube import CompositeCubeView
from memos.multi_mem_cube.single_cube import SingleCubeView
from memos.multi_mem_cube.views import MemCubeView
@@ -50,9 +55,19 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
"""
self.logger.info(f"[SearchHandler] Search Req is: {search_req}")

# Increase recall pool if deduplication is enabled to ensure diversity
original_top_k = search_req.top_k
if search_req.dedup == "sim":
search_req.top_k = original_top_k * 5

cube_view = self._build_cube_view(search_req)

results = cube_view.search_memories(search_req)
if search_req.dedup == "sim":
results = self._dedup_text_memories(results, original_top_k)
self._strip_embeddings(results)
# Restore original top_k for downstream logic or response metadata
search_req.top_k = original_top_k

self.logger.info(
f"[SearchHandler] Final search results: count={len(results)} results={results}"
@@ -63,6 +78,93 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
data=results,
)

def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]:
buckets = results.get("text_mem", [])
if not buckets:
return results

flat: list[tuple[int, dict[str, Any], float]] = []
for bucket_idx, bucket in enumerate(buckets):
for mem in bucket.get("memories", []):
score = mem.get("metadata", {}).get("relativity", 0.0)
flat.append((bucket_idx, mem, score))

if len(flat) <= 1:
return results

embeddings = self._extract_embeddings([mem for _, mem, _ in flat])
if embeddings is None:
documents = [mem.get("memory", "") for _, mem, _ in flat]
embeddings = self.searcher.embedder.embed(documents)

similarity_matrix = cosine_similarity_matrix(embeddings)

indices_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}
for flat_index, (bucket_idx, _, _) in enumerate(flat):
indices_by_bucket[bucket_idx].append(flat_index)

selected_global: list[int] = []
selected_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}

ordered_indices = sorted(range(len(flat)), key=lambda idx: flat[idx][2], reverse=True)
for idx in ordered_indices:
bucket_idx = flat[idx][0]
if len(selected_by_bucket[bucket_idx]) >= target_top_k:
continue
# Use 0.92 threshold strictly
if self._is_unrelated(idx, selected_global, similarity_matrix, 0.92):
selected_by_bucket[bucket_idx].append(idx)
selected_global.append(idx)

# Removed the 'filling' logic that was pulling back similar items.
# Now it will only return items that truly pass the 0.92 threshold,
# up to target_top_k.

for bucket_idx, bucket in enumerate(buckets):
selected_indices = selected_by_bucket.get(bucket_idx, [])
bucket["memories"] = [flat[i][1] for i in selected_indices]
return results

@staticmethod
def _is_unrelated(
index: int,
selected_indices: list[int],
similarity_matrix: list[list[float]],
similarity_threshold: float,
) -> bool:
return all(similarity_matrix[index][j] <= similarity_threshold for j in selected_indices)

@staticmethod
def _max_similarity(
index: int, selected_indices: list[int], similarity_matrix: list[list[float]]
) -> float:
if not selected_indices:
return 0.0
return max(similarity_matrix[index][j] for j in selected_indices)

@staticmethod
def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None:
embeddings: list[list[float]] = []
for mem in memories:
embedding = mem.get("metadata", {}).get("embedding")
if not embedding:
return None
embeddings.append(embedding)
return embeddings

@staticmethod
def _strip_embeddings(results: dict[str, Any]) -> None:
for bucket in results.get("text_mem", []):
for mem in bucket.get("memories", []):
metadata = mem.get("metadata", {})
if "embedding" in metadata:
metadata["embedding"] = []
for bucket in results.get("tool_mem", []):
for mem in bucket.get("memories", []):
metadata = mem.get("metadata", {})
if "embedding" in metadata:
metadata["embedding"] = []

def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
"""
Normalize target cube ids from search_req.
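For reference, a self-contained sketch of the greedy selection that _dedup_text_memories performs: items are ordered by their relativity score, and an item survives only if its cosine similarity to every already-selected item stays at or below 0.92. The per-bucket bookkeeping of the real handler is omitted, and cosine similarities are computed with plain numpy here rather than with cosine_similarity_matrix from retrieve_utils:

```python
import numpy as np


def greedy_similarity_dedup(
    embeddings: list[list[float]],
    scores: list[float],
    top_k: int,
    threshold: float = 0.92,
) -> list[int]:
    """Return indices of the highest-scoring items whose pairwise cosine
    similarity never exceeds `threshold`, capped at `top_k` survivors."""
    vectors = np.asarray(embeddings, dtype=float)
    norms = np.clip(np.linalg.norm(vectors, axis=1, keepdims=True), 1e-12, None)
    similarity = (vectors / norms) @ (vectors / norms).T

    selected: list[int] = []
    for idx in sorted(range(len(scores)), key=lambda i: scores[i], reverse=True):
        if len(selected) >= top_k:
            break
        if all(similarity[idx][j] <= threshold for j in selected):
            selected.append(idx)
    return selected


# The near-duplicate second vector is dropped and, unlike before, never refilled:
assert greedy_similarity_dedup(
    embeddings=[[1.0, 0.0], [0.99, 0.1], [0.0, 1.0]],
    scores=[0.9, 0.8, 0.7],
    top_k=2,
) == [0, 2]
```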
9 changes: 9 additions & 0 deletions src/memos/api/product_models.py
@@ -319,6 +319,15 @@ class APISearchRequest(BaseRequest):
description="Number of textual memories to retrieve (top-K). Default: 10.",
)

dedup: Literal["no", "sim"] | None = Field(
None,
description=(
"Optional dedup option for textual memories. "
"Use 'no' for no dedup, 'sim' for similarity dedup. "
"If None, default exact-text dedup is applied."
),
)

pref_top_k: int = Field(
6,
ge=0,
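An illustrative request fragment showing how a caller might exercise the new field; only fields visible in this diff are shown, and everything else the endpoint requires is omitted:

```python
payload = {
    "top_k": 10,       # final number of textual memories wanted
    "pref_top_k": 6,
    "dedup": "sim",    # widen recall to 5x top_k, then similarity-dedup at 0.92
    # "dedup": "no"    -> skip dedup entirely
    # omit "dedup"     -> default exact-text dedup
}
```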
3 changes: 2 additions & 1 deletion src/memos/mem_reader/multi_modal_struct.py
@@ -10,6 +10,7 @@
from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang
from memos.mem_reader.read_multi_modal.base import _derive_key
from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
from memos.mem_reader.utils import parse_json_result
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
from memos.types import MessagesType
@@ -377,7 +378,7 @@ def _get_llm_response(
messages = [{"role": "user", "content": prompt}]
try:
response_text = self.llm.generate(messages)
response_json = self.parse_json_result(response_text)
response_json = parse_json_result(response_text)
except Exception as e:
logger.error(f"[LLM] Exception during chat generation: {e}")
response_json = {
11 changes: 9 additions & 2 deletions src/memos/mem_scheduler/optimized_scheduler.py
@@ -186,10 +186,14 @@ def mix_search_memories(
info=info,
search_tool_memory=search_req.search_tool_memory,
tool_mem_top_k=search_req.tool_mem_top_k,
dedup=search_req.dedup,
)
memories = merged_memories[: search_req.top_k]

formatted_memories = [format_textual_memory_item(item) for item in memories]
formatted_memories = [
format_textual_memory_item(item, include_embedding=search_req.dedup == "sim")
for item in memories
]
self.submit_memory_history_async_task(
search_req=search_req,
user_context=user_context,
@@ -233,7 +237,10 @@ def update_search_memories_to_redis(
mem_cube=self.mem_cube,
mode=SearchMode.FAST,
)
formatted_memories = [format_textual_memory_item(data) for data in memories]
formatted_memories = [
format_textual_memory_item(data, include_embedding=search_req.dedup == "sim")
for data in memories
]
else:
memories = [
TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"]
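Embeddings are retained in the formatted items only when dedup == "sim", so the search handler can reuse the stored vectors for similarity dedup instead of re-embedding every hit (_strip_embeddings then clears them before the response is returned). A minimal sketch of that reuse-or-fallback pattern, with a dummy embedder standing in for the real one:

```python
from typing import Any, Callable


def embeddings_for_dedup(
    memories: list[dict[str, Any]],
    embed_fn: Callable[[list[str]], list[list[float]]],
) -> list[list[float]]:
    """Reuse vectors kept by include_embedding=True; re-embed only if any are missing."""
    stored = [m.get("metadata", {}).get("embedding") for m in memories]
    if all(stored):
        return stored
    return embed_fn([m.get("memory", "") for m in memories])


mems = [
    {"memory": "a", "metadata": {"embedding": [0.1]}},
    {"memory": "b", "metadata": {"embedding": []}},  # one missing vector forces a full re-embed
]
assert embeddings_for_dedup(mems, embed_fn=lambda docs: [[0.0] for _ in docs]) == [[0.0], [0.0]]
```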
5 changes: 3 additions & 2 deletions src/memos/mem_scheduler/utils/api_utils.py
@@ -6,14 +6,15 @@
from memos.memories.textual.tree import TextualMemoryItem


def format_textual_memory_item(memory_data: Any) -> dict[str, Any]:
def format_textual_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]:
"""Format a single memory item for API response."""
memory = memory_data.model_dump()
memory_id = memory["id"]
ref_id = f"[{memory_id.split('-')[0]}]"

memory["ref_id"] = ref_id
memory["metadata"]["embedding"] = []
if not include_embedding:
memory["metadata"]["embedding"] = []
memory["metadata"]["sources"] = []
memory["metadata"]["ref_id"] = ref_id
memory["metadata"]["id"] = memory_id
2 changes: 2 additions & 0 deletions src/memos/memories/textual/tree.py
@@ -161,6 +161,7 @@ def search(
user_name: str | None = None,
search_tool_memory: bool = False,
tool_mem_top_k: int = 6,
dedup: str | None = None,
**kwargs,
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
@@ -207,6 +208,7 @@
user_name=user_name,
search_tool_memory=search_tool_memory,
tool_mem_top_k=tool_mem_top_k,
dedup=dedup,
**kwargs,
)

@@ -4,7 +4,6 @@
from pathlib import Path
from typing import Any


import numpy as np

from memos.dependency import require_python_package
51 changes: 7 additions & 44 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -119,9 +119,13 @@ def post_retrieve(
info=None,
search_tool_memory: bool = False,
tool_mem_top_k: int = 6,
dedup: str | None = None,
plugin=False,
):
deduped = self._deduplicate_results(retrieved_results)
if dedup == "no":
deduped = retrieved_results
else:
deduped = self._deduplicate_results(retrieved_results)
final_results = self._sort_and_trim(
deduped, top_k, plugin, search_tool_memory, tool_mem_top_k
)
@@ -141,6 +145,7 @@ def search(
user_name: str | None = None,
search_tool_memory: bool = False,
tool_mem_top_k: int = 6,
dedup: str | None = None,
**kwargs,
) -> list[TextualMemoryItem]:
"""
@@ -202,6 +207,7 @@ def search(
plugin=kwargs.get("plugin", False),
search_tool_memory=search_tool_memory,
tool_mem_top_k=tool_mem_top_k,
dedup=dedup,
)

logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -284,49 +290,6 @@ def _parse_task(

return parsed_goal, query_embedding, context, query

@timed
def _retrieve_simple(
self,
query: str,
top_k: int,
search_filter: dict | None = None,
user_name: str | None = None,
**kwargs,
):
"""Retrieve from by keywords and embedding"""
query_words = []
if self.tokenizer:
query_words = self.tokenizer.tokenize_mixed(query)
else:
query_words = query.strip().split()
query_words = [query, *query_words]
logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
query_embeddings = self.embedder.embed(query_words)

items = self.graph_retriever.retrieve_from_mixed(
top_k=top_k * 2,
memory_scope=None,
query_embedding=query_embeddings,
search_filter=search_filter,
user_name=user_name,
use_fast_graph=self.use_fast_graph,
)
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
documents = [getattr(item, "memory", "") for item in items]
documents_embeddings = self.embedder.embed(documents)
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
selected_items = [items[i] for i in selected_indices]
logger.info(
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
)
return self.reranker.rerank(
query=query,
query_embedding=query_embeddings[0],
graph_results=selected_items,
top_k=top_k,
)

@timed
def _retrieve_paths(
self,
@@ -5,7 +5,10 @@
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
FastTokenizer,
parse_json_result,
)
from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT


@@ -111,8 +114,10 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
for attempt_times in range(attempts):
try:
context = kwargs.get("context", "")
response = response.replace("```", "").replace("json", "").strip()
response_json = eval(response)
response_json = parse_json_result(response)
if not response_json:
raise ValueError("Parsed JSON is empty")

return ParsedTaskGoal(
memories=response_json.get("memories", []),
keys=response_json.get("keys", []),
Expand All @@ -123,6 +128,8 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
context=context,
)
except Exception as e:
raise ValueError(
f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts + 1}"
) from e
if attempt_times == attempts - 1:
raise ValueError(
f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts}"
) from e
continue
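Replacing eval(response) with parse_json_result removes a code-execution risk on arbitrary LLM output, and the reworked except block lets the retry loop exhaust its attempts before failing. The helper below is an illustrative stand-in, assuming the parser only needs to recover a single JSON object from possibly fenced text; the real implementation in memos.mem_reader.utils is not shown in this PR:

```python
import json
import re


def parse_json_result_sketch(response_text: str) -> dict:
    """Hypothetical stand-in for parse_json_result: pull the first JSON object
    out of the (possibly fenced) model output and decode it, instead of eval()."""
    match = re.search(r"\{.*\}", response_text, flags=re.DOTALL)
    if not match:
        return {}
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return {}


assert parse_json_result_sketch('{"keys": ["a"], "memories": []}') == {"keys": ["a"], "memories": []}
assert parse_json_result_sketch("not json at all") == {}
```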