Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions mesa_llm/memory/st_lt_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,30 +72,46 @@ def __init__(

self.llm.system_prompt = self.system_prompt

def _build_consolidation_prompt(self) -> str:
def _build_consolidation_prompt(
self, popped_memories: list[MemoryEntry] | None = None
) -> str:
"""
Prompt builder function to reduce redundancy
"""
if popped_memories:
lines = []
for st_memory_entry in popped_memories:
lines.append(
f"Step {st_memory_entry.step}: \n{st_memory_entry.content}"
)
short_term_str = "\n".join(lines)
else:
short_term_str = self.format_short_term()

return f"""
Short term memory:
{self.format_short_term()}
{short_term_str}
Long term memory:
{self.long_term_memory}
"""

def _update_long_term_memory(self):
def _update_long_term_memory(
self, popped_memories: list[MemoryEntry] | None = None
):
"""
Update the long term memory by summarizing the short term memory with a LLM
"""
prompt = self._build_consolidation_prompt()
prompt = self._build_consolidation_prompt(popped_memories)
response = self.llm.generate(prompt)
self.long_term_memory = response.choices[0].message.content

async def _aupdate_long_term_memory(self):
async def _aupdate_long_term_memory(
self, popped_memories: list[MemoryEntry] | None = None
):
"""
Async Function to update long term memory
"""
prompt = self._build_consolidation_prompt()
prompt = self._build_consolidation_prompt(popped_memories)
response = await self.llm.agenerate(prompt)
self.long_term_memory = response.choices[0].message.content

Expand Down Expand Up @@ -130,31 +146,31 @@ def _process_step_core(self, pre_step: bool):
self.short_term_memory.append(new_entry)
self.step_content = {}

should_consolidate = False
popped_memories = []
if (
len(self.short_term_memory)
> self.capacity + (self.consolidation_capacity or 0)
and self.consolidation_capacity
):
self.short_term_memory.popleft()
should_consolidate = True
for _ in range(self.consolidation_capacity):
popped_memories.append(self.short_term_memory.popleft())

elif (
len(self.short_term_memory) > self.capacity
and not self.consolidation_capacity
):
self.short_term_memory.popleft()

return new_entry, should_consolidate
return new_entry, popped_memories

def process_step(self, pre_step: bool = False):
    """
    Synchronous memory step handler.

    Args:
        pre_step: Forwarded unchanged to ``self._process_step_core``.

    Runs the core step logic, consolidates any entries it popped from
    short-term memory into long-term memory, and displays the new entry when
    ``self.display`` is truthy.
    """
    new_entry, popped_memories = self._process_step_core(pre_step)

    # An empty list means nothing was evicted, so the LLM call is skipped.
    if popped_memories:
        self._update_long_term_memory(popped_memories)

    if new_entry and self.display:
        new_entry.display()
Expand All @@ -163,10 +179,10 @@ async def aprocess_step(self, pre_step: bool = False):
"""
Async memory step handler (non-blocking consolidation)
"""
new_entry, should_consolidate = self._process_step_core(pre_step)
new_entry, popped_memories = self._process_step_core(pre_step)

if should_consolidate:
await self._aupdate_long_term_memory()
if popped_memories:
await self._aupdate_long_term_memory(popped_memories)

if new_entry and self.display:
new_entry.display()
Expand Down
81 changes: 81 additions & 0 deletions tests/test_memory/test_STLT_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,84 @@ def test_get_prompt_ready_returns_str_when_empty(self, mock_agent):
)
assert "Short term memory:" in result
assert "Long term memory:" in result

def test_memory_does_not_drop_oldest_during_consolidation(
    self, mock_agent, mock_llm, llm_response_factory
):
    """
    The consolidation prompt must still contain the oldest short-term entry.

    Regression test for Issue #186.
    """
    mock_llm.generate.return_value = llm_response_factory(
        "Consolidated memory summary"
    )

    stlt = STLTMemory(
        agent=mock_agent,
        short_term_capacity=2,
        consolidation_capacity=1,
        llm_model="provider/test_model",
    )
    stlt.llm = mock_llm

    # Four steps exceed capacity (2) + consolidation_capacity (1), which
    # triggers consolidation.
    with patch("rich.console.Console"):
        for step in range(4):
            stlt.step_content = {"observation": f"critical event at step {step}"}
            stlt.process_step(pre_step=True)
            mock_agent.model.steps = step + 1
            stlt.step_content = {"action": f"response to step {step}"}
            stlt.process_step()

    # First positional argument of the most recent llm.generate call.
    prompt_sent = mock_llm.generate.call_args[0][0]

    # The oldest entry has to appear in what the LLM was asked to summarize.
    assert "critical event at step 0" in prompt_sent, (
        "Oldest memory entry was silently dropped before consolidation!"
    )

def test_memory_pops_and_summarizes_exact_k_items(
    self, mock_agent, mock_llm, llm_response_factory
):
    """
    Exactly consolidation_capacity entries are popped on overflow, and only
    those popped entries are sent to the LLM for summarization. Issue #107.
    """
    mock_llm.generate.return_value = llm_response_factory(
        "Consolidated memory summary"
    )

    stlt = STLTMemory(
        agent=mock_agent,
        short_term_capacity=2,
        consolidation_capacity=2,
        llm_model="provider/test_model",
    )
    stlt.llm = mock_llm

    # Five items are added: the first four fill capacity +
    # consolidation_capacity (2 + 2 = 4), and the fifth triggers consolidation.
    with patch("rich.console.Console"):
        for step in range(5):
            stlt.step_content = {"observation": f"event at step {step}"}
            stlt.process_step(pre_step=True)
            mock_agent.model.steps = step + 1
            stlt.step_content = {"action": f"response to step {step}"}
            stlt.process_step()

    # Two entries (steps 0 and 1) are popped, leaving capacity (2) plus the
    # new entry: steps 2, 3 and 4.
    assert len(stlt.short_term_memory) == 3

    prompt_sent = mock_llm.generate.call_args[0][0]

    # Popped entries must be in the prompt...
    for popped_step in (0, 1):
        assert f"event at step {popped_step}" in prompt_sent

    # ...and entries that stayed in short-term memory must not be.
    for kept_step in (2, 3, 4):
        assert f"event at step {kept_step}" not in prompt_sent