Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 49 additions & 30 deletions mesa_llm/memory/st_lt_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,42 +72,56 @@ def __init__(

self.llm.system_prompt = self.system_prompt

def _build_consolidation_prompt(self, evicted_entries: list[MemoryEntry]) -> str:
    """
    Build a prompt that asks the LLM to integrate *evicted* memories
    into the existing long-term summary.

    Args:
        evicted_entries: the oldest short-term entries that were just
            removed from the deque and need to be summarized.

    Returns:
        The prompt text: one "Step <n>:" section per evicted entry,
        followed by the current long-term memory and the instruction
        to produce an updated summary.
    """
    # Each entry is rendered from its own step number and content so the
    # LLM sees exactly what is leaving short-term memory.
    evicted_text = "\n".join(
        f"Step {e.step}: \n{e.content}" for e in evicted_entries
    )
    return (
        "Memories to consolidate (oldest entries being removed "
        "from short-term memory):\n"
        f"{evicted_text}\n\n"
        f"Existing long term memory:\n{self.long_term_memory}\n\n"
        "Please integrate the above memories into a concise, updated "
        "long-term memory summary."
    )

def _update_long_term_memory(self, evicted_entries: list[MemoryEntry]):
    """
    Update the long term memory by summarizing the evicted entries.

    Args:
        evicted_entries: entries removed from short-term memory that
            should be folded into the long-term summary.
    """
    prompt = self._build_consolidation_prompt(evicted_entries)
    response = self.llm.generate(prompt)
    # Keep only the text of the first choice, not the response object.
    self.long_term_memory = response.choices[0].message.content

async def _aupdate_long_term_memory(self, evicted_entries: list[MemoryEntry]):
    """
    Async version of _update_long_term_memory.

    Args:
        evicted_entries: entries removed from short-term memory that
            should be folded into the long-term summary.
    """
    prompt = self._build_consolidation_prompt(evicted_entries)
    response = await self.llm.agenerate(prompt)
    # Keep only the text of the first choice, not the response object.
    self.long_term_memory = response.choices[0].message.content

def _process_step_core(self, pre_step: bool):
"""
Shared core logic for process_step and aprocess_step
Shared core logic for process_step and aprocess_step.

Update short-term memory and decide if consolidation is needed.
When entries are evicted for consolidation they are captured and
returned so the caller can pass them to the LLM for summarization.

Returns:
"(new_entry, should_consolidate)"
``(new_entry, evicted_entries)`` where *evicted_entries* is a
(possibly empty) list of MemoryEntry objects that were removed
from short-term memory and should be consolidated.
"""
# Add the new entry to the short term memory
if pre_step:
new_entry = MemoryEntry(
agent=self.agent,
Expand All @@ -116,45 +130,50 @@ def _process_step_core(self, pre_step: bool):
)
self.short_term_memory.append(new_entry)
self.step_content = {}
return None, False
return None, []

if not self.short_term_memory or self.short_term_memory[-1].step is not None:
return None, False
return None, []

pre_step_entry = self.short_term_memory.pop()
self.step_content.update(pre_step_entry.content)
new_entry = MemoryEntry(
agent=self.agent,
content=self.step_content,
step=self.agent.model.steps,
)

self.short_term_memory.append(new_entry)
self.step_content = {}

should_consolidate = False
evicted: list[MemoryEntry] = []

if (
len(self.short_term_memory)
> self.capacity + (self.consolidation_capacity or 0)
and self.consolidation_capacity
):
self.short_term_memory.popleft()
should_consolidate = True
# Pop consolidation_capacity oldest entries for summarization
for _ in range(self.consolidation_capacity):
if self.short_term_memory:
evicted.append(self.short_term_memory.popleft())

elif (
len(self.short_term_memory) > self.capacity
and not self.consolidation_capacity
):
# No consolidation configured — just discard the oldest entry
self.short_term_memory.popleft()

return new_entry, should_consolidate
return new_entry, evicted

def process_step(self, pre_step: bool = False):
    """
    Synchronous memory step handler.

    Args:
        pre_step: forwarded unchanged to ``_process_step_core``.
    """
    new_entry, evicted = self._process_step_core(pre_step)

    # An empty eviction list means nothing left short-term memory, so
    # there is nothing to summarize and the LLM is not called.
    if evicted:
        self._update_long_term_memory(evicted)

    if new_entry and self.display:
        new_entry.display()
Expand All @@ -163,10 +182,10 @@ async def aprocess_step(self, pre_step: bool = False):
"""
Async memory step handler (non-blocking consolidation)
"""
new_entry, should_consolidate = self._process_step_core(pre_step)
new_entry, evicted = self._process_step_core(pre_step)

if should_consolidate:
await self._aupdate_long_term_memory()
if evicted:
await self._aupdate_long_term_memory(evicted)

if new_entry and self.display:
new_entry.display()
Expand Down
74 changes: 69 additions & 5 deletions tests/test_memory/test_STLT_memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections import deque
from unittest.mock import patch
from unittest.mock import AsyncMock, patch

import pytest

from mesa_llm.memory.memory import MemoryEntry
from mesa_llm.memory.st_lt_memory import STLTMemory
Expand Down Expand Up @@ -120,11 +122,15 @@ def test_update_long_term_memory(self, mock_agent, mock_llm, llm_response_factor
memory.llm = mock_llm
memory.long_term_memory = "Previous memory"

memory._update_long_term_memory()
evicted = [
MemoryEntry(
content={"observation": "old content"}, step=0, agent=mock_agent
)
]
memory._update_long_term_memory(evicted)

call_args = mock_llm.generate.call_args[0][0]
assert "Short term memory:" in call_args
assert "Long term memory:" in call_args
assert "old content" in call_args
assert "Previous memory" in call_args

# Must be a plain string, not a ModelResponse object
Expand All @@ -145,13 +151,71 @@ def test_long_term_memory_stores_string_not_response_object(
memory = STLTMemory(agent=mock_agent, llm_model="provider/test_model")
memory.llm = mock_llm

memory._update_long_term_memory()
evicted = [MemoryEntry(content={"data": "evicted"}, step=0, agent=mock_agent)]
memory._update_long_term_memory(evicted)

assert isinstance(memory.long_term_memory, str), (
"long_term_memory must be a string, not a ModelResponse object"
)
assert memory.long_term_memory == "This is the summary text"

def test_consolidation_receives_evicted_entries(
    self, mock_agent, mock_llm, llm_response_factory
):
    """Regression test for #107: evicted entries must be passed to the
    LLM for summarization, not the remaining short-term memories."""
    mock_llm.generate.return_value = llm_response_factory("Consolidated summary")

    memory = STLTMemory(
        agent=mock_agent,
        short_term_capacity=2,
        consolidation_capacity=2,
        llm_model="provider/test_model",
    )
    memory.llm = mock_llm

    # Fill up: 2 (capacity) + 2 (consolidation) + 1 to trigger
    with patch("rich.console.Console"):
        for i in range(5):
            memory.add_to_memory("observation", {"content": f"step_{i}"})
            memory.process_step(pre_step=True)
            mock_agent.model.steps = i + 1
            memory.process_step(pre_step=False)

    # Only the fifth entry exceeds capacity + consolidation_capacity, so
    # consolidation must have run exactly once.
    mock_llm.generate.assert_called_once()
    prompt = mock_llm.generate.call_args[0][0]

    # The prompt template itself always contains "consolidate"/"removed",
    # so checking for those words proves nothing. Assert on the payloads:
    # the two oldest entries are summarized...
    for evicted_marker in ("step_0", "step_1"):
        assert evicted_marker in prompt
    # ...and the entries still held in short-term memory are not.
    for retained_marker in ("step_2", "step_3", "step_4"):
        assert retained_marker not in prompt

    assert memory.long_term_memory == "Consolidated summary"

@pytest.mark.asyncio
async def test_aupdate_long_term_memory(
    self, mock_agent, mock_llm, llm_response_factory
):
    """Exercise the async consolidation path with a mocked ``agenerate``."""
    expected_summary = "Async consolidated summary"
    mock_llm.agenerate = AsyncMock(
        return_value=llm_response_factory(expected_summary)
    )

    stlt = STLTMemory(agent=mock_agent, llm_model="provider/test_model")
    stlt.llm = mock_llm
    stlt.long_term_memory = "Old summary"

    dropped_entry = MemoryEntry(
        content={"observation": "evicted data"}, step=0, agent=mock_agent
    )
    await stlt._aupdate_long_term_memory([dropped_entry])

    # The LLM is awaited once, with a prompt carrying both the evicted
    # content and the previous summary.
    mock_llm.agenerate.assert_called_once()
    sent_prompt = mock_llm.agenerate.call_args.args[0]
    for fragment in ("evicted data", "Old summary"):
        assert fragment in sent_prompt

    # The stored summary is the plain response text.
    stored = stlt.long_term_memory
    assert isinstance(stored, str)
    assert stored == expected_summary

def test_observation_tracking(self, mock_agent):
"""Test that observations are properly tracked and only changes stored"""
memory = STLTMemory(agent=mock_agent, llm_model="provider/test_model")
Expand Down
Loading