diff --git a/mesa_llm/memory/st_lt_memory.py b/mesa_llm/memory/st_lt_memory.py index a4de6836..6b53c1c0 100644 --- a/mesa_llm/memory/st_lt_memory.py +++ b/mesa_llm/memory/st_lt_memory.py @@ -72,30 +72,46 @@ def __init__( self.llm.system_prompt = self.system_prompt - def _build_consolidation_prompt(self) -> str: + def _build_consolidation_prompt( + self, popped_memories: list[MemoryEntry] | None = None + ) -> str: """ Prompt builder function to reduce redundancy """ + if popped_memories: + lines = [] + for st_memory_entry in popped_memories: + lines.append( + f"Step {st_memory_entry.step}: \n{st_memory_entry.content}" + ) + short_term_str = "\n".join(lines) + else: + short_term_str = self.format_short_term() + return f""" Short term memory: - {self.format_short_term()} + {short_term_str} Long term memory: {self.long_term_memory} """ - def _update_long_term_memory(self): + def _update_long_term_memory( + self, popped_memories: list[MemoryEntry] | None = None + ): """ Update the long term memory by summarizing the short term memory with a LLM """ - prompt = self._build_consolidation_prompt() + prompt = self._build_consolidation_prompt(popped_memories) response = self.llm.generate(prompt) self.long_term_memory = response.choices[0].message.content - async def _aupdate_long_term_memory(self): + async def _aupdate_long_term_memory( + self, popped_memories: list[MemoryEntry] | None = None + ): """ Async Function to update long term memory """ - prompt = self._build_consolidation_prompt() + prompt = self._build_consolidation_prompt(popped_memories) response = await self.llm.agenerate(prompt) self.long_term_memory = response.choices[0].message.content @@ -130,14 +146,14 @@ def _process_step_core(self, pre_step: bool): self.short_term_memory.append(new_entry) self.step_content = {} - should_consolidate = False + popped_memories = [] if ( len(self.short_term_memory) > self.capacity + (self.consolidation_capacity or 0) and self.consolidation_capacity ): - 
self.short_term_memory.popleft() - should_consolidate = True + for _ in range(self.consolidation_capacity): + popped_memories.append(self.short_term_memory.popleft()) elif ( len(self.short_term_memory) > self.capacity @@ -145,16 +161,16 @@ def _process_step_core(self, pre_step: bool): ): self.short_term_memory.popleft() - return new_entry, should_consolidate + return new_entry, popped_memories def process_step(self, pre_step: bool = False): """ Synchronous memory step handler """ - new_entry, should_consolidate = self._process_step_core(pre_step) + new_entry, popped_memories = self._process_step_core(pre_step) - if should_consolidate: - self._update_long_term_memory() + if popped_memories: + self._update_long_term_memory(popped_memories) if new_entry and self.display: new_entry.display() @@ -163,10 +179,10 @@ async def aprocess_step(self, pre_step: bool = False): """ Async memory step handler (non-blocking consolidation) """ - new_entry, should_consolidate = self._process_step_core(pre_step) + new_entry, popped_memories = self._process_step_core(pre_step) - if should_consolidate: - await self._aupdate_long_term_memory() + if popped_memories: + await self._aupdate_long_term_memory(popped_memories) if new_entry and self.display: new_entry.display() diff --git a/tests/test_memory/test_STLT_memory.py b/tests/test_memory/test_STLT_memory.py index c0d90696..36296c0d 100644 --- a/tests/test_memory/test_STLT_memory.py +++ b/tests/test_memory/test_STLT_memory.py @@ -200,3 +200,84 @@ def test_get_prompt_ready_returns_str_when_empty(self, mock_agent): ) assert "Short term memory:" in result assert "Long term memory:" in result + + def test_memory_does_not_drop_oldest_during_consolidation( + self, mock_agent, mock_llm, llm_response_factory + ): + """ + Verify that the oldest short-term entry is passed to the LLM in the + consolidation prompt instead of being dropped first. Fixes Issue #186. 
+ """ + mock_llm.generate.return_value = llm_response_factory( + "Consolidated memory summary" + ) + + memory = STLTMemory( + agent=mock_agent, + short_term_capacity=2, + consolidation_capacity=1, + llm_model="provider/test_model", + ) + memory.llm = mock_llm + + # Run capacity + consolidation_capacity + 1 = 4 steps so the final step triggers consolidation + with patch("rich.console.Console"): + for i in range(4): + memory.step_content = {"observation": f"critical event at step {i}"} + memory.process_step(pre_step=True) + mock_agent.model.steps = i + 1 + memory.step_content = {"action": f"response to step {i}"} + memory.process_step() + + # First positional argument (the prompt) of the most recent llm.generate call + call_args = mock_llm.generate.call_args[0][0] + + # The oldest entry MUST be in the prompt sent to the LLM + assert "critical event at step 0" in call_args, ( + "Oldest memory entry was silently dropped before consolidation!" + ) + + def test_memory_pops_and_summarizes_exact_k_items( + self, mock_agent, mock_llm, llm_response_factory + ): + """ + Verify that exactly consolidation_capacity items are popped when capacity is exceeded + and ONLY those popped items are sent to the LLM for summarization. Issue #107. + """ + mock_llm.generate.return_value = llm_response_factory( + "Consolidated memory summary" + ) + + memory = STLTMemory( + agent=mock_agent, + short_term_capacity=2, + consolidation_capacity=2, + llm_model="provider/test_model", + ) + memory.llm = mock_llm + + # We will add 5 items. The first 4 fill the capacity + consolidation_capacity (2 + 2 = 4). + # When the 5th item is added, consolidation should trigger. + with patch("rich.console.Console"): + for i in range(5): + memory.step_content = {"observation": f"event at step {i}"} + memory.process_step(pre_step=True) + mock_agent.model.steps = i + 1 + memory.step_content = {"action": f"response to step {i}"} + memory.process_step() + + # At the 5th item (i=4), consolidation is triggered. 
+ # Exactly consolidation_capacity (2) items are popped (i=0 and i=1), + # so the queue keeps capacity (2) + 1 (new entry) = 3 items (i=2, i=3, i=4) + assert len(memory.short_term_memory) == 3 + + call_args = mock_llm.generate.call_args[0][0] + + # The prompt must contain the popped items (0 and 1) + assert "event at step 0" in call_args + assert "event at step 1" in call_args + + # The prompt MUST NOT contain items that were NOT popped (2, 3, 4) + assert "event at step 2" not in call_args + assert "event at step 3" not in call_args + assert "event at step 4" not in call_args