@@ -22,13 +22,15 @@ def __init__(
2222 graph_store : Neo4jGraphDB ,
2323 embedder : OllamaEmbedder ,
2424 bm25_retriever : EnhancedBM25 | None = None ,
25+ include_embedding : bool = False ,
2526 ):
2627 self .graph_store = graph_store
2728 self .embedder = embedder
2829 self .bm25_retriever = bm25_retriever
2930 self .max_workers = 10
3031 self .filter_weight = 0.6
3132 self .use_bm25 = bool (self .bm25_retriever )
33+ self .include_embedding = include_embedding
3234
3335 def retrieve (
3436 self ,
@@ -72,7 +74,7 @@ def retrieve(
7274 # For working memory, retrieve all entries (no session-oriented filtering)
7375 working_memories = self .graph_store .get_all_memory_items (
7476 scope = "WorkingMemory" ,
75- include_embedding = False ,
77+ include_embedding = self . include_embedding ,
7678 user_name = user_name ,
7779 filter = search_filter ,
7880 )
@@ -244,7 +246,9 @@ def process_node(node):
244246 return []
245247
246248 # Load nodes and post-filter
247- node_dicts = self .graph_store .get_nodes (list (candidate_ids ), include_embedding = False )
249+ node_dicts = self .graph_store .get_nodes (
250+ list (candidate_ids ), include_embedding = self .include_embedding
251+ )
248252
249253 final_nodes = []
250254 for node in node_dicts :
@@ -291,7 +295,7 @@ def process_node(node):
291295
292296 # Load nodes and post-filter
293297 node_dicts = self .graph_store .get_nodes (
294- list (candidate_ids ), include_embedding = False , user_name = user_name
298+ list (candidate_ids ), include_embedding = self . include_embedding , user_name = user_name
295299 )
296300
297301 final_nodes = []
@@ -385,7 +389,10 @@ def search_path_b():
385389 unique_ids = {r ["id" ] for r in all_hits if r .get ("id" )}
386390 node_dicts = (
387391 self .graph_store .get_nodes (
388- list (unique_ids ), include_embedding = False , cube_name = cube_name , user_name = user_name
392+ list (unique_ids ),
393+ include_embedding = self .include_embedding ,
394+ cube_name = cube_name ,
395+ user_name = user_name ,
389396 )
390397 or []
391398 )
@@ -416,7 +423,9 @@ def _bm25_recall(
416423 key_filters .append ({"field" : key , "op" : "=" , "value" : value })
417424 corpus_name += "" .join (list (search_filter .values ()))
418425 candidate_ids = self .graph_store .get_by_metadata (key_filters , user_name = user_name )
419- node_dicts = self .graph_store .get_nodes (list (candidate_ids ), include_embedding = False )
426+ node_dicts = self .graph_store .get_nodes (
427+ list (candidate_ids ), include_embedding = self .include_embedding
428+ )
420429
421430 bm25_query = " " .join (list ({query , * parsed_goal .keys }))
422431 bm25_results = self .bm25_retriever .search (
@@ -471,7 +480,10 @@ def _fulltext_recall(
471480 unique_ids = {r ["id" ] for r in all_hits if r .get ("id" )}
472481 node_dicts = (
473482 self .graph_store .get_nodes (
474- list (unique_ids ), include_embedding = False , cube_name = cube_name , user_name = user_name
483+ list (unique_ids ),
484+ include_embedding = self .include_embedding ,
485+ cube_name = cube_name ,
486+ user_name = user_name ,
475487 )
476488 or []
477489 )
0 commit comments