@@ -44,11 +44,26 @@ def retrieve(
4444 raise ValueError (f"Unsupported memory scope: { memory_scope } " )
4545
4646 if memory_scope == "WorkingMemory" :
47- # For working memory, retrieve all entries (no filtering)
48- # TODO: use search filter if exists
47+ # For working memory, retrieve all entries with optional filtering
4948 working_memories = self .graph_store .get_all_memory_items (
5049 scope = "WorkingMemory" , include_embedding = True
5150 )
51+
52+ # Apply search_filter if provided
53+ if search_filter :
54+ filtered_memories = []
55+ for record in working_memories :
56+ metadata = record .get ("metadata" , {})
57+ # Check if all search_filter conditions are met
58+ match = True
59+ for key , value in search_filter .items ():
60+ if metadata .get (key ) != value :
61+ match = False
62+ break
63+ if match :
64+ filtered_memories .append (record )
65+ working_memories = filtered_memories
66+
5267 return [TextualMemoryItem .from_dict (record ) for record in working_memories ]
5368
5469 with concurrent .futures .ThreadPoolExecutor (max_workers = 2 ) as executor :
@@ -131,11 +146,11 @@ def _graph_recall(
131146 self , parsed_goal : ParsedTaskGoal , memory_scope : str , search_filter : dict | None = None
132147 ) -> list [TextualMemoryItem ]:
133148 """
134- TODO: use search filter if exists
135149 Perform structured node-based retrieval from Neo4j.
136150 - keys must match exactly (n.key IN keys)
137151 - tags must overlap with at least 2 input tags
138152 - scope filters by memory_type if provided
153+ - search_filter applies additional metadata filtering
139154 """
140155 candidate_ids = set ()
141156
@@ -179,6 +194,14 @@ def _graph_recall(
179194 overlap = len (set (node_tags ) & set (parsed_goal .tags ))
180195 if overlap >= 2 :
181196 keep = True
197+
198+ # Apply search_filter if provided
199+ if keep and search_filter :
200+ for key , value in search_filter .items ():
201+ if meta .get (key ) != value :
202+ keep = False
203+ break
204+
182205 if keep :
183206 final_nodes .append (TextualMemoryItem .from_dict (node ))
184207 return final_nodes
0 commit comments