Skip to content

Commit 78e14ea

Browse files
authored
feat: update filter mem (#285)
* feat: update filter mem * fix:change top * fix:rm embedding
1 parent 7bb26a9 commit 78e14ea

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

src/memos/mem_os/product.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -704,16 +704,39 @@ def run_async_in_thread():
704704
thread.start()
705705

706706
def _filter_memories_by_threshold(
707-
self, memories: list[TextualMemoryItem], threshold: float = 0.30, min_num: int = 3
707+
self,
708+
memories: list[TextualMemoryItem],
709+
threshold: float = 0.30,
710+
min_num: int = 3,
711+
memory_type: Literal["OuterMemory"] = "OuterMemory",
708712
) -> list[TextualMemoryItem]:
709713
"""
710-
Filter memories by threshold.
714+
Filter memories by threshold and type, at least min_num memories for Non-OuterMemory.
715+
Args:
716+
memories: list[TextualMemoryItem],
717+
threshold: float,
718+
min_num: int,
719+
memory_type: Literal["OuterMemory"],
720+
Returns:
721+
list[TextualMemoryItem]
711722
"""
712723
sorted_memories = sorted(memories, key=lambda m: m.metadata.relativity, reverse=True)
713-
filtered = [m for m in sorted_memories if m.metadata.relativity >= threshold]
724+
filtered_person = [m for m in memories if m.metadata.memory_type != memory_type]
725+
filtered_outer = [m for m in memories if m.metadata.memory_type == memory_type]
726+
filtered = []
727+
per_memory_count = 0
728+
for m in sorted_memories:
729+
if m.metadata.relativity >= threshold:
730+
if m.metadata.memory_type != memory_type:
731+
per_memory_count += 1
732+
filtered.append(m)
714733
if len(filtered) < min_num:
715-
filtered = sorted_memories[:min_num]
716-
return filtered
734+
filtered = filtered_person[:min_num] + filtered_outer[:min_num]
735+
else:
736+
if len(per_memory_count) < min_num:
737+
filtered += filtered_person[per_memory_count:min_num]
738+
filtered_memory = sorted(filtered, key=lambda m: m.metadata.relativity, reverse=True)
739+
return filtered_memory
717740

718741
def register_mem_cube(
719742
self,
@@ -919,6 +942,11 @@ def chat(
919942
if memories_result:
920943
memories_list = memories_result[0]["memories"]
921944
memories_list = self._filter_memories_by_threshold(memories_list, threshold)
945+
new_memories_list = []
946+
for m in memories_list:
947+
m.metadata.embedding = []
948+
new_memories_list.append(m)
949+
memories_list = new_memories_list
922950
system_prompt = super()._build_system_prompt(memories_list, base_prompt)
923951
history_info = []
924952
if history:
@@ -949,7 +977,7 @@ def chat_with_references(
949977
user_id: str,
950978
cube_id: str | None = None,
951979
history: MessageList | None = None,
952-
top_k: int = 10,
980+
top_k: int = 20,
953981
internet_search: bool = False,
954982
moscube: bool = False,
955983
) -> Generator[str, None, None]:

0 commit comments

Comments
 (0)