Skip to content

Commit 8db52c1

Browse files
authored
Merge pull request #109 from psbuilds/fix/episodic-memory
fix: store graded EpisodicMemory entries as MemoryEntry objects and use correct llm instance
2 parents f888a0a + e124ec8 commit 8db52c1

File tree

3 files changed

+325
-98
lines changed

3 files changed

+325
-98
lines changed

mesa_llm/memory/episodic_memory.py

Lines changed: 128 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,60 @@ class EventGrade(BaseModel):
1414
grade: int
1515

1616

17+
def normalize_dict_values(scores: dict, min_target: float, max_target: float) -> dict:
    """
    Min-max scale the values of ``scores`` into [min_target, max_target], in place.

    Mirrors the min-max helper used in the Generative Agents reference
    retrieval implementation:
    https://github.com/joonspk-research/generative_agents/blob/main/reverie/backend_server/persona/cognitive_modules/retrieve.py

    When every value is identical (zero spread) all entries collapse to the
    midpoint of the target range. The mutated dict is also returned; an empty
    input yields a fresh empty dict.
    """
    if not scores:
        return {}

    lo = min(scores.values())
    hi = max(scores.values())
    spread = hi - lo

    if spread == 0:
        # Degenerate case: no variation, so every entry maps to the midpoint.
        midpoint = (max_target - min_target) / 2 + min_target
        for key in scores:
            scores[key] = midpoint
        return scores

    for key, value in scores.items():
        scores[key] = (value - lo) * (max_target - min_target) / spread + min_target
    return scores
45+
46+
1747
class EpisodicMemory(Memory):
1848
"""
19-
Stores memories based on event importance scoring. Each new memory entry is evaluated by a LLM
20-
for its relevance and importance (1-5 scale) relative to the agent's current task and previous
21-
experiences. Based on a Stanford/DeepMind paper:
22-
[Generative Agents: Interactive Simulacra of Human Behavior](https://arxiv.org/pdf/2304.03442)
49+
Event-level memory with LLM-based importance scoring and recency-aware retrieval.
50+
51+
Credit / references:
52+
- Paper: Generative Agents: Interactive Simulacra of Human Behavior
53+
https://arxiv.org/abs/2304.03442
54+
- Reference retrieval code:
55+
https://github.com/joonspk-research/generative_agents/blob/main/reverie/backend_server/persona/cognitive_modules/retrieve.py
56+
57+
This implementation is inspired by the paper's retrieval scoring design
58+
(component-wise min-max normalization, then weighted combination). It is
59+
not a strict copy of the original code: relevance scoring via embeddings is
60+
not implemented yet, and recency is computed from step age.
2361
"""
2462

2563
def __init__(
2664
self,
2765
agent: "LLMAgent",
2866
llm_model: str | None = None,
2967
display: bool = True,
30-
max_capacity: int = 10,
31-
considered_entries: int = 5,
68+
max_capacity: int = 200,
69+
considered_entries: int = 30,
70+
recency_decay: float = 0.995,
3271
):
3372
"""
3473
Initialize the EpisodicMemory
@@ -43,6 +82,7 @@ def __init__(
4382
self.max_capacity = max_capacity
4483
self.memory_entries = deque(maxlen=self.max_capacity)
4584
self.considered_entries = considered_entries
85+
self.recency_decay = recency_decay
4686

4787
self.system_prompt = """
4888
You are an assistant that evaluates memory entries on a scale from 1 to 5, based on their importance to a specific problem or task. Your goal is to assign a score that reflects how much each entry contributes to understanding, solving, or advancing the task. Use the following grading scale:
@@ -60,6 +100,24 @@ def __init__(
60100
Only assess based on the entry's content and its value to the task at hand. Ignore style, grammar, or tone.
61101
"""
62102

103+
def _extract_importance(self, entry) -> int:
104+
"""
105+
Safely extracts importance score regardless of data structure.
106+
Handles:
107+
- Nested: {"msg": {"importance": 5}}
108+
- Flat: {"importance": 5}
109+
"""
110+
if "importance" in entry.content:
111+
val = entry.content["importance"]
112+
return val if isinstance(val, (int, float)) else 1
113+
114+
for value in entry.content.values():
115+
if isinstance(value, dict) and "importance" in value:
116+
val = value["importance"]
117+
return val if isinstance(val, (int, float)) else 1
118+
119+
return 1
120+
63121
def _build_grade_prompt(self, type: str, content: dict) -> str:
64122
"""
65123
This helper assembles a prompt that includes the event type, event content,
@@ -89,7 +147,7 @@ def grade_event_importance(self, type: str, content: dict) -> float:
89147
prompt = self._build_grade_prompt(type, content)
90148
self.llm.system_prompt = self.system_prompt
91149

92-
rsp = self.agent.llm.generate(
150+
rsp = self.llm.generate(
93151
prompt=prompt,
94152
response_format=EventGrade,
95153
)
@@ -104,7 +162,7 @@ async def agrade_event_importance(self, type: str, content: dict) -> float:
104162
prompt = self._build_grade_prompt(type, content)
105163
self.llm.system_prompt = self.system_prompt
106164

107-
rsp = await self.agent.llm.agenerate(
165+
rsp = await self.llm.agenerate(
108166
prompt=prompt,
109167
response_format=EventGrade,
110168
)
@@ -114,30 +172,68 @@ async def agrade_event_importance(self, type: str, content: dict) -> float:
114172

115173
def retrieve_top_k_entries(self, k: int) -> list[MemoryEntry]:
    """
    Return the k highest-scoring entries by normalized importance + recency.

    Inspired by the Generative Agents retrieval design: each component is
    min-max normalized to [0, 1] separately, then summed. Relevance
    (embedding cosine similarity against a focal query) is not implemented
    yet; only importance and recency contribute to the score.
    """
    entries = list(self.memory_entries)
    if not entries:
        return []

    now = self.agent.model.steps
    importance = {}
    recency = {}
    for idx, entry in enumerate(entries):
        importance[idx] = self._extract_importance(entry)
        # Exponential decay by step age: older entries contribute less.
        recency[idx] = self.recency_decay ** (now - entry.step)

    importance = normalize_dict_values(importance, 0, 1)
    recency = normalize_dict_values(recency, 0, 1)

    scored = [
        (importance[idx] + recency[idx], entry)
        for idx, entry in enumerate(entries)
    ]
    # Stable sort keeps insertion order among ties, matching the original.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [entry for _, entry in scored[:k]]
208+
209+
def _finalize_entry(self, type: str, graded_content: dict):
    """Persist ``graded_content`` as a MemoryEntry keyed by its event type."""
    self.memory_entries.append(
        MemoryEntry(
            agent=self.agent,
            content={type: graded_content},
            step=self.agent.model.steps,
        )
    )
126217

127218
def add_to_memory(self, type: str, content: dict):
    """
    Grade the event's importance with the LLM, then persist the graded copy.

    The caller's ``content`` dict is left untouched: a shallow copy augmented
    with an ``"importance"`` score is what gets stored.
    """
    graded = {**content}
    graded["importance"] = self.grade_event_importance(type, content)
    self._finalize_entry(type, graded)
134227

135228
async def aadd_to_memory(self, type: str, content: dict):
    """
    Async version of add_to_memory: grade importance, then store a graded copy.

    The caller's ``content`` dict is not mutated; the stored copy carries the
    awaited ``"importance"`` score.
    """
    graded = {**content}
    graded["importance"] = await self.agrade_event_importance(type, content)
    self._finalize_entry(type, graded)
141237

142238
def get_prompt_ready(self) -> str:
143239
return f"Top {self.considered_entries} memory entries:\n\n" + "\n".join(
@@ -161,20 +257,18 @@ def get_communication_history(self) -> str:
161257

162258
async def aprocess_step(self, pre_step: bool = False):
    """
    Async step hook — intentionally a no-op.

    EpisodicMemory persists entries immediately at add-time, so the two-phase
    pre/post-step buffering used by other memory types is not needed here.
    """
    return None
170266

171267
def process_step(self, pre_step: bool = False):
    """
    Step hook — intentionally a no-op.

    EpisodicMemory persists entries immediately at add-time, so the two-phase
    pre/post-step buffering used by other memory types is not needed here.
    """
    return None

tests/test_integration/test_memory_reasoning.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,14 @@ def test_plan_records_to_memory(self, monkeypatch):
280280
plan = reasoning.plan(obs=obs)
281281

282282
assert isinstance(plan, Plan)
283-
assert memory.step_content["Observation"]["content"] == str(obs)
284-
assert memory.step_content["Plan"]["content"] == plan_content
285-
assert memory.step_content["Plan-Execution"]["content"] == str(plan)
286-
assert memory.step_content["Observation"]["importance"] == 3
287-
assert memory.step_content["Plan"]["importance"] == 3
288-
assert memory.step_content["Plan-Execution"]["importance"] == 3
283+
entries = list(memory.memory_entries)
284+
assert len(entries) == 3
285+
assert entries[0].content["Observation"]["content"] == str(obs)
286+
assert entries[1].content["Plan"]["content"] == plan_content
287+
assert entries[2].content["Plan-Execution"]["content"] == str(plan)
288+
assert entries[0].content["Observation"]["importance"] == 3
289+
assert entries[1].content["Plan"]["importance"] == 3
290+
assert entries[2].content["Plan-Execution"]["importance"] == 3
289291
assert memory.grade_event_importance.call_count == 3
290292

291293
def test_async_plan_works(self, monkeypatch):
@@ -301,12 +303,14 @@ def test_async_plan_works(self, monkeypatch):
301303
plan = asyncio.run(reasoning.aplan(obs=obs))
302304

303305
assert isinstance(plan, Plan)
304-
assert memory.step_content["Observation"]["content"] == str(obs)
305-
assert memory.step_content["Plan"]["content"] == plan_content
306-
assert memory.step_content["Plan-Execution"]["content"] == str(plan)
307-
assert memory.step_content["Observation"]["importance"] == 3
308-
assert memory.step_content["Plan"]["importance"] == 3
309-
assert memory.step_content["Plan-Execution"]["importance"] == 3
306+
entries = list(memory.memory_entries)
307+
assert len(entries) == 3
308+
assert entries[0].content["Observation"]["content"] == str(obs)
309+
assert entries[1].content["Plan"]["content"] == plan_content
310+
assert entries[2].content["Plan-Execution"]["content"] == str(plan)
311+
assert entries[0].content["Observation"]["importance"] == 3
312+
assert entries[1].content["Plan"]["importance"] == 3
313+
assert entries[2].content["Plan-Execution"]["importance"] == 3
310314
assert memory.agrade_event_importance.await_count == 3
311315

312316

@@ -677,8 +681,10 @@ def test_plan_records_to_memory(self, monkeypatch):
677681

678682
plan = reasoning.plan()
679683
assert isinstance(plan, Plan)
680-
assert memory.step_content["plan"]["content"] == plan_content
681-
assert memory.step_content["plan"]["importance"] == 3
684+
entries = list(memory.memory_entries)
685+
assert len(entries) == 1
686+
assert entries[0].content["plan"]["content"] == plan_content
687+
assert entries[0].content["plan"]["importance"] == 3
682688
assert memory.grade_event_importance.call_count == 1
683689
reasoning.execute_tool_call.assert_called_once_with(
684690
plan_content, selected_tools=None, ttl=1
@@ -699,8 +705,10 @@ def test_async_plan_works(self, monkeypatch):
699705

700706
plan = asyncio.run(reasoning.aplan())
701707
assert isinstance(plan, Plan)
702-
assert memory.step_content["plan"]["content"] == plan_content
703-
assert memory.step_content["plan"]["importance"] == 3
708+
entries = list(memory.memory_entries)
709+
assert len(entries) == 1
710+
assert entries[0].content["plan"]["content"] == plan_content
711+
assert entries[0].content["plan"]["importance"] == 3
704712
assert memory.grade_event_importance.call_count == 1
705713
reasoning.aexecute_tool_call.assert_awaited_once_with(
706714
plan_content, selected_tools=None, ttl=1

0 commit comments

Comments
 (0)