Skip to content

Commit 16ec612

Browse files
committed
feat: add filter in working memory
1 parent f2d275b commit 16ec612

File tree

1 file changed

+26
-3
lines changed
  • src/memos/memories/textual/tree_text_memory/retrieve

1 file changed

+26
-3
lines changed

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,26 @@ def retrieve(
4444
raise ValueError(f"Unsupported memory scope: {memory_scope}")
4545

4646
if memory_scope == "WorkingMemory":
47-
# For working memory, retrieve all entries (no filtering)
48-
# TODO: use search filter if exists
47+
# For working memory, retrieve all entries with optional filtering
4948
working_memories = self.graph_store.get_all_memory_items(
5049
scope="WorkingMemory", include_embedding=True
5150
)
51+
52+
# Apply search_filter if provided
53+
if search_filter:
54+
filtered_memories = []
55+
for record in working_memories:
56+
metadata = record.get("metadata", {})
57+
# Check if all search_filter conditions are met
58+
match = True
59+
for key, value in search_filter.items():
60+
if metadata.get(key) != value:
61+
match = False
62+
break
63+
if match:
64+
filtered_memories.append(record)
65+
working_memories = filtered_memories
66+
5267
return [TextualMemoryItem.from_dict(record) for record in working_memories]
5368

5469
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
@@ -131,11 +146,11 @@ def _graph_recall(
131146
self, parsed_goal: ParsedTaskGoal, memory_scope: str, search_filter: dict | None = None
132147
) -> list[TextualMemoryItem]:
133148
"""
134-
TODO: use search filter if exists
135149
Perform structured node-based retrieval from Neo4j.
136150
- keys must match exactly (n.key IN keys)
137151
- tags must overlap with at least 2 input tags
138152
- scope filters by memory_type if provided
153+
- search_filter applies additional metadata filtering
139154
"""
140155
candidate_ids = set()
141156

@@ -179,6 +194,14 @@ def _graph_recall(
179194
overlap = len(set(node_tags) & set(parsed_goal.tags))
180195
if overlap >= 2:
181196
keep = True
197+
198+
# Apply search_filter if provided
199+
if keep and search_filter:
200+
for key, value in search_filter.items():
201+
if meta.get(key) != value:
202+
keep = False
203+
break
204+
182205
if keep:
183206
final_nodes.append(TextualMemoryItem.from_dict(node))
184207
return final_nodes

0 commit comments

Comments
 (0)