diff --git a/mesa_llm/llm_agent.py b/mesa_llm/llm_agent.py index 3ed8a32d..35d8dc40 100644 --- a/mesa_llm/llm_agent.py +++ b/mesa_llm/llm_agent.py @@ -262,15 +262,22 @@ async def asend_message(self, message: str, recipients: list[Agent]) -> str: """ Asynchronous version of send_message. """ - for recipient in [*recipients, self]: + for recipient in recipients: await recipient.memory.aadd_to_memory( type="message", content={ "message": message, "sender": self.unique_id, - "recipients": [r.unique_id for r in recipients], }, ) + await self.memory.aadd_to_memory( + type="message", + content={ + "message": message, + "sender": self.unique_id, + "recipients": [r.unique_id for r in recipients], + }, + ) return f"{self} → {recipients} : {message}" @@ -278,15 +285,22 @@ def send_message(self, message: str, recipients: list[Agent]) -> str: """ Send a message to the recipients. """ - for recipient in [*recipients, self]: + for recipient in recipients: recipient.memory.add_to_memory( type="message", content={ "message": message, "sender": self.unique_id, - "recipients": [r.unique_id for r in recipients], }, ) + self.memory.add_to_memory( + type="message", + content={ + "message": message, + "sender": self.unique_id, + "recipients": [r.unique_id for r in recipients], + }, + ) return f"{self} → {recipients} : {message}" diff --git a/mesa_llm/tools/inbuilt_tools.py b/mesa_llm/tools/inbuilt_tools.py index 9003d4ee..adca03d4 100644 --- a/mesa_llm/tools/inbuilt_tools.py +++ b/mesa_llm/tools/inbuilt_tools.py @@ -1,3 +1,4 @@ +import logging from typing import TYPE_CHECKING, Any from mesa.discrete_space import ( @@ -15,6 +16,8 @@ if TYPE_CHECKING: from mesa_llm.llm_agent import LLMAgent +logger = logging.getLogger(__name__) + # Mapping directions to (dx, dy) for Cartesian-style spaces. direction_map_xy = { "North": (0, 1), @@ -207,15 +210,35 @@ def speak_to( and listener_agent.unique_id != agent.unique_id ] + delivered_ids = [] + skipped_ids = [] + for recipient in listener_agents: + if not hasattr(recipient, "memory"): + skipped_ids.append(recipient.unique_id) + logger.warning( + "Agent %s has no memory attribute; skipping speak_to.", + recipient.unique_id, + ) + continue + delivered_ids.append(recipient.unique_id) recipient.memory.add_to_memory( type="message", content={ "message": message, "sender": agent.unique_id, - "recipients": [ - listener_agent.unique_id for listener_agent in listener_agents - ], }, ) - return f"{agent.unique_id} → {[agent.unique_id for agent in listener_agents]} : {message}" + + status_parts = [] + if delivered_ids: + status_parts.append(f"sent message {message!r} to {delivered_ids}") + if skipped_ids: + status_parts.append( + f"skipped {skipped_ids} because they have no `memory` attribute" + ) + + if not status_parts: + return f"Could not send message {message!r}: no matching recipients found." + + return "; ".join(status_parts) diff --git a/tests/test_llm_agent.py b/tests/test_llm_agent.py index a3180211..2c5aad5b 100644 --- a/tests/test_llm_agent.py +++ b/tests/test_llm_agent.py @@ -276,22 +276,109 @@ def add_agent(self, pos, agent_class=LLMAgent): ) recipient.unique_id = 2 - # Track how many times add_to_memory is called - call_counter = {"count": 0} + recorded_calls = [] def fake_add_to_memory(*args, **kwargs): - call_counter["count"] += 1 + recorded_calls.append(("sender", kwargs)) + + def fake_recipient_add_to_memory(*args, **kwargs): + recorded_calls.append(("recipient", kwargs)) # monkeypatch both agents' memory modules monkeypatch.setattr(sender.memory, "add_to_memory", fake_add_to_memory) - monkeypatch.setattr(recipient.memory, "add_to_memory", fake_add_to_memory) + monkeypatch.setattr(recipient.memory, "add_to_memory", fake_recipient_add_to_memory) result = sender.send_message("hello", recipients=[recipient]) pattern = r"LLMAgent 1 → \[\] : hello" assert re.match(pattern, result) # sender + recipient memory => should be called twice - assert call_counter["count"] == 2 + assert len(recorded_calls) == 2 + sender_call = next(call for label, call in recorded_calls if label == "sender") + recipient_call = next( + call for label, call in recorded_calls if label == "recipient" + ) + assert sender_call["type"] == "message" + assert sender_call["content"]["message"] == "hello" + assert sender_call["content"]["sender"] == sender.unique_id + assert sender_call["content"]["recipients"] == [recipient.unique_id] + assert recipient_call["type"] == "message" + assert recipient_call["content"]["message"] == "hello" + assert recipient_call["content"]["sender"] == sender.unique_id + assert "recipients" not in recipient_call["content"] + + +@pytest.mark.asyncio +async def test_asend_message_updates_both_agents_memory(monkeypatch): + monkeypatch.setenv("GEMINI_API_KEY", "dummy") + + class DummyModel(Model): + def __init__(self): + super().__init__(seed=45) + self.grid = MultiGrid(3, 3, torus=False) + + def add_agent(self, pos, agent_class=LLMAgent): + system_prompt = "You are an agent in a simulation." + agents = agent_class.create_agents( + self, + n=1, + reasoning=lambda agent: None, + system_prompt=system_prompt, + vision=-1, + internal_state=["test_state"], + ) + x, y = pos + agent = agents.to_list()[0] + self.grid.place_agent(agent, (x, y)) + return agent + + model = DummyModel() + sender = model.add_agent((0, 0)) + sender.memory = ShortTermMemory( + agent=sender, + n=5, + display=True, + ) + sender.unique_id = 1 + + recipient = model.add_agent((1, 1)) + recipient.memory = ShortTermMemory( + agent=recipient, + n=5, + display=True, + ) + recipient.unique_id = 2 + + recorded_calls = [] + + async def fake_aadd_to_memory(*args, **kwargs): + recorded_calls.append(("sender", kwargs)) + + async def fake_recipient_aadd_to_memory(*args, **kwargs): + recorded_calls.append(("recipient", kwargs)) + + monkeypatch.setattr(sender.memory, "aadd_to_memory", fake_aadd_to_memory) + monkeypatch.setattr( + recipient.memory, "aadd_to_memory", fake_recipient_aadd_to_memory + ) + + result = await sender.asend_message("hello", recipients=[recipient]) + pattern = r"LLMAgent 1 → \[\] : hello" + assert re.match(pattern, result) + + assert len(recorded_calls) == 2 + sender_call = next(call for label, call in recorded_calls if label == "sender") + recipient_call = next( + call for label, call in recorded_calls if label == "recipient" + ) + assert sender_call["type"] == "message" + assert sender_call["content"]["message"] == "hello" + assert sender_call["content"]["sender"] == sender.unique_id + assert sender_call["content"]["recipients"] == [recipient.unique_id] + assert recipient_call["type"] == "message" + assert recipient_call["content"]["message"] == "hello" + assert recipient_call["content"]["sender"] == sender.unique_id + assert "recipients" not in recipient_call["content"] @pytest.mark.asyncio @@ -833,13 +920,13 @@ def capture_content(type, content): sender.send_message("hello", recipients=[recipient]) assert captured["sender"] == 10 - assert captured["recipients"] == [20] assert captured["message"] == "hello" # Must not raise TypeError when serializing data = json.loads(json.dumps(captured)) assert data["sender"] == 10 - assert data["recipients"] == [20] + assert "recipients" not in data # recipients only stored in sender, not recipient + assert data["message"] == "hello" @pytest.mark.asyncio @@ -861,9 +948,12 @@ async def noop(*a, **kw): await sender.asend_message("hello", recipients=[recipient]) assert captured["sender"] == 10 - assert captured["recipients"] == [20] + assert ( + "recipients" not in captured + ) # recipients only stored in sender, not recipient assert captured["message"] == "hello" data = json.loads(json.dumps(captured)) assert data["sender"] == 10 - assert data["recipients"] == [20] + assert data["message"] == "hello" + assert "recipients" not in data diff --git a/tests/test_tools/test_inbuilt_tools.py b/tests/test_tools/test_inbuilt_tools.py index 7a9c367c..7946247a 100644 --- a/tests/test_tools/test_inbuilt_tools.py +++ b/tests/test_tools/test_inbuilt_tools.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from types import SimpleNamespace import pytest @@ -170,10 +171,8 @@ def test_speak_to_records_on_recipients(mocker): content = kwargs["content"] assert content["message"] == message assert content["sender"] == sender.unique_id - assert set(content["recipients"]) == {11, 12} - - # Return string contains sender and recipients list - assert "10" in ret and "11" in ret and "12" in ret and message in ret + assert "recipients" not in content + assert ret == "sent message 'Hello there' to [11, 12]" def test_move_one_step_invalid_direction(): @@ -582,3 +581,63 @@ class _DummyOrthogonalGrid(OrthogonalMooreGrid): assert agent.cell is wrapped_cell assert result == "agent 29 moved to (2, 2)." + + +def test_speak_to_skips_non_llm_recipient(mocker): + """ + speak_to must not crash when a recipient has no memory attribute. + + This covers the case where a non-LLM (rule-based) agent is listed as a + recipient. + """ + model = DummyModel() + + sender = DummyAgent(unique_id=1, model=model) + llm_recipient = DummyAgent(unique_id=2, model=model) + rule_recipient = DummyAgent(unique_id=3, model=model) + + llm_recipient.memory = SimpleNamespace(add_to_memory=mocker.Mock()) + + model.agents = [sender, llm_recipient, rule_recipient] + + ret = speak_to(sender, [2, 3], "Hello both") + + llm_recipient.memory.add_to_memory.assert_called_once() + call_kwargs = llm_recipient.memory.add_to_memory.call_args[1] + assert call_kwargs["type"] == "message" + assert call_kwargs["content"]["message"] == "Hello both" + assert "recipients" not in call_kwargs["content"] + + assert ret == ( + "sent message 'Hello both' to [2]; skipped [3] because they have no `memory` attribute" + ) + + +def test_speak_to_warns_for_non_llm_recipient(mocker, caplog): + model = DummyModel() + sender = DummyAgent(unique_id=10, model=model) + rule_recipient = DummyAgent(unique_id=11, model=model) # no .memory + + model.agents = [sender, rule_recipient] + + with caplog.at_level(logging.WARNING, logger="mesa_llm.tools.inbuilt_tools"): + ret = speak_to(sender, [11], "Test message") + + assert any( + "11" in record.message and "memory" in record.message + for record in caplog.records + ) + assert ret == "skipped [11] because they have no `memory` attribute" + + +def test_speak_to_returns_clear_message_when_no_valid_recipients(): + model = DummyModel() + sender = DummyAgent(unique_id=20, model=model) + + model.agents = [sender] + + ret = speak_to(sender, [20, 999], "Anyone there?") + + assert ( + ret == "Could not send message 'Anyone there?': no matching recipients found." + )