33==========================================
44
55Builds context combining all 6 MIRIX memory types for task execution.
6+ Now with entity-based associative retrieval for connected reasoning.
67
78Based on: MIRIX Active Retrieval Mechanism (arXiv:2507.07957)
89Retrieval scoring: Stanford Generative Agents (recency + importance + relevance)
10+ Entity connections: Anthropic context patterns + Mem0 graph memory
911"""
1012
1113from __future__ import annotations
1214
1315import logging
1416from dataclasses import dataclass , field
15- from typing import Dict , List , Optional , TYPE_CHECKING
17+ from typing import Dict , List , Optional , Set , TYPE_CHECKING
1618
1719from episodic_memory .models .memory import Memory , MemoryType
1820
1921if TYPE_CHECKING :
20- from core .memory_store import MemoryStore
22+ from episodic_memory . core .memory_store import MemoryStore
2123
2224
2325logger = logging .getLogger (__name__ )
2426
27+ # Lazy import entity index to avoid circular imports
28+ _entity_extractor = None
29+ _entity_index = None
30+
31+
32+ def _get_entity_extractor ():
33+ """Lazy load entity extractor."""
34+ global _entity_extractor
35+ if _entity_extractor is None :
36+ from episodic_memory .core .entity_index import EntityExtractor
37+ _entity_extractor = EntityExtractor ()
38+ return _entity_extractor
39+
40+
41+ def _get_entity_index ():
42+ """Lazy load entity index."""
43+ global _entity_index
44+ if _entity_index is None :
45+ from episodic_memory .core .entity_index import EntityIndex
46+ _entity_index = EntityIndex .load_from_disk () or EntityIndex ()
47+ return _entity_index
48+
2549
2650# Default top-K per memory type
2751DEFAULT_TOP_K : Dict [MemoryType , int ] = {
3761@dataclass
3862class MemoryContext : # pylint: disable=too-many-instance-attributes
3963 """
40- Context combining all 6 MIRIX memory types.
64+ Context combining all 6 MIRIX memory types + connected memories .
4165
4266 Attributes:
4367 core: Identity + user facts (always included)
@@ -46,7 +70,9 @@ class MemoryContext: # pylint: disable=too-many-instance-attributes
4670 procedural: Applicable workflows/skills
4771 resource: Referenced documents
4872 vault: High-confidence consolidated memories
73+ connected: Memories connected via shared entities (associative)
4974 retrieval_scores: Score per retrieved memory
75+ entity_connections: Mapping of memory_id -> shared entities
5076 task: Original task query
5177 """
5278
@@ -56,7 +82,9 @@ class MemoryContext: # pylint: disable=too-many-instance-attributes
5682 procedural : List [Memory ] = field (default_factory = list )
5783 resource : List [Memory ] = field (default_factory = list )
5884 vault : List [Memory ] = field (default_factory = list )
85+ connected : List [Memory ] = field (default_factory = list )
5986 retrieval_scores : Dict [str , float ] = field (default_factory = dict )
87+ entity_connections : Dict [str , List [str ]] = field (default_factory = dict )
6088 task : str = ""
6189
6290 def to_prompt_context (self ) -> str :
@@ -92,6 +120,17 @@ def to_prompt_context(self) -> str:
92120 vault_content = "\n " .join (f"- { m .content } " for m in self .vault )
93121 sections .append (f"[VAULT (HIGH CONFIDENCE)]\n { vault_content } " )
94122
123+ if self .connected :
124+ # Format connected memories with their shared entities
125+ connected_lines = []
126+ for m in self .connected :
127+ shared = self .entity_connections .get (m .memory_id , [])
128+ if shared :
129+ connected_lines .append (f"- { m .content } (via: { ', ' .join (shared )} )" )
130+ else :
131+ connected_lines .append (f"- { m .content } " )
132+ sections .append (f"[CONNECTED (ASSOCIATED MEMORIES)]\n " + "\n " .join (connected_lines ))
133+
95134 return "\n \n " .join (sections )
96135
97136 def total_memories (self ) -> int :
@@ -102,7 +141,8 @@ def total_memories(self) -> int:
102141 len (self .semantic ) +
103142 len (self .procedural ) +
104143 len (self .resource ) +
105- len (self .vault )
144+ len (self .vault ) +
145+ len (self .connected )
106146 )
107147
108148 def to_dict (self ) -> Dict :
@@ -115,7 +155,9 @@ def to_dict(self) -> Dict:
115155 "procedural" : [m .model_dump () for m in self .procedural ],
116156 "resource" : [m .model_dump () for m in self .resource ],
117157 "vault" : [m .model_dump () for m in self .vault ],
158+ "connected" : [m .model_dump () for m in self .connected ],
118159 "retrieval_scores" : self .retrieval_scores ,
160+ "entity_connections" : self .entity_connections ,
119161 "total_memories" : self .total_memories (),
120162 }
121163
@@ -125,48 +167,57 @@ class ContextBuilder: # pylint: disable=too-few-public-methods
125167 Builds multi-type memory context for tasks.
126168
127169 Based on: MIRIX Active Retrieval Mechanism.
128- Uses keyword matching for now, extensible to vector search .
170+ Uses keyword matching + entity-based associative retrieval .
129171 """
130172
131173 def __init__ (
132174 self ,
133175 memory_store : "MemoryStore" ,
134- top_k : Optional [Dict [MemoryType , int ]] = None
176+ top_k : Optional [Dict [MemoryType , int ]] = None ,
177+ max_connected : int = 5
135178 ) -> None :
136179 """
137180 Initialize context builder.
138181
139182 Args:
140183 memory_store: Memory store instance
141184 top_k: Optional custom top-K per type (defaults to DEFAULT_TOP_K)
185+ max_connected: Maximum connected memories to retrieve
142186 """
143187 self .store = memory_store
144188 self .top_k = top_k or DEFAULT_TOP_K .copy ()
189+ self .max_connected = max_connected
145190
146191 async def get_context_for_task (self , task : str ) -> MemoryContext :
147192 """
148- Build context combining all 6 MIRIX memory types.
193+ Build context combining all 6 MIRIX memory types + connected memories .
149194
150195 Pipeline:
151196 1. Extract keywords from task
152- 2. Search each memory type (keyword match + importance sort)
153- 3. Compute retrieval scores
154- 4. Return combined MemoryContext
197+ 2. Extract entities from task (for associative retrieval)
198+ 3. Search each memory type (keyword match + importance sort)
199+ 4. Find connected memories via entity index
200+ 5. Compute retrieval scores
201+ 6. Return combined MemoryContext
155202
156203 Args:
157204 task: Task description
158205
159206 Returns:
160- MemoryContext with memories from all 6 types
207+ MemoryContext with memories from all 6 types + connections
161208
162209 Example:
163- >>> context = await builder.get_context_for_task("write unit tests ")
164- >>> len(context.procedural)
165- 5
210+ >>> context = await builder.get_context_for_task("what did Juan say about memory? ")
211+ >>> len(context.connected) # Memories connected via "Juan" entity
212+ 3
166213 """
167214 keywords = self ._extract_keywords (task )
168215 results : Dict [MemoryType , List [Memory ]] = {}
169216 scores : Dict [str , float ] = {}
217+
218+ # Extract entities from task for associative retrieval
219+ extractor = _get_entity_extractor ()
220+ task_entities = extractor .extract (task )
170221
171222 # Search each memory type
172223 for memory_type in MemoryType :
@@ -189,6 +240,13 @@ async def get_context_for_task(self, task: str) -> MemoryContext:
189240 scores [mem .memory_id ] = self ._compute_retrieval_score (mem , keywords )
190241 mem .record_access () # Track access for decay boosting
191242
243+ # Find connected memories via entity index
244+ connected_memories , entity_connections = await self ._find_connected_memories (
245+ task_entities ,
246+ results ,
247+ scores
248+ )
249+
192250 # Build context
193251 context = MemoryContext (
194252 core = results .get (MemoryType .CORE , []),
@@ -197,15 +255,70 @@ async def get_context_for_task(self, task: str) -> MemoryContext:
197255 procedural = results .get (MemoryType .PROCEDURAL , []),
198256 resource = results .get (MemoryType .RESOURCE , []),
199257 vault = results .get (MemoryType .VAULT , []),
258+ connected = connected_memories ,
200259 retrieval_scores = scores ,
260+ entity_connections = entity_connections ,
201261 task = task ,
202262 )
203263
204264 logger .info (
205- "Built context for task '%s': %d memories" ,
206- task [:50 ], context .total_memories ()
265+ "Built context for task '%s': %d memories (%d connected via entities) " ,
266+ task [:50 ], context .total_memories (), len ( connected_memories )
207267 )
208268 return context
269+
270+ async def _find_connected_memories (
271+ self ,
272+ task_entities : List [str ],
273+ direct_results : Dict [MemoryType , List [Memory ]],
274+ scores : Dict [str , float ]
275+ ) -> tuple [List [Memory ], Dict [str , List [str ]]]:
276+ """
277+ Find memories connected via shared entities.
278+
279+ Args:
280+ task_entities: Entities extracted from task
281+ direct_results: Already retrieved memories (to exclude)
282+ scores: Scores dict to update
283+
284+ Returns:
285+ Tuple of (connected_memories, entity_connections)
286+ """
287+ if not task_entities :
288+ return [], {}
289+
290+ index = _get_entity_index ()
291+
292+ # Get IDs of already retrieved memories
293+ already_retrieved : Set [str ] = set ()
294+ for memories in direct_results .values ():
295+ already_retrieved .update (m .memory_id for m in memories )
296+
297+ # Find related memories via entity index
298+ related = index .get_related_memories (
299+ task_entities ,
300+ exclude_ids = already_retrieved ,
301+ min_overlap = 1
302+ )
303+
304+ connected_memories : List [Memory ] = []
305+ entity_connections : Dict [str , List [str ]] = {}
306+
307+ for memory_id , overlap_count in related [:self .max_connected ]:
308+ memory = self .store ._storage .get (memory_id ) # pylint: disable=protected-access
309+ if memory :
310+ connected_memories .append (memory )
311+ # Find which entities connect this memory
312+ memory_entities = set (memory .context .get ("entities" , []))
313+ shared = list (memory_entities & set (task_entities ))
314+ entity_connections [memory_id ] = shared
315+
316+ # Compute and store score
317+ base_score = 0.5 + (overlap_count * 0.1 ) # Boost for entity overlap
318+ scores [memory_id ] = min (1.0 , base_score )
319+ memory .record_access ()
320+
321+ return connected_memories , entity_connections
209322
210323 async def _search_type (
211324 self ,
0 commit comments