Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions mesa_llm/module_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
RateLimitError,
Timeout,
)
from pydantic import BaseModel
from tenacity import AsyncRetrying, retry, retry_if_exception_type, wait_exponential

RETRYABLE_EXCEPTIONS = (
Expand Down Expand Up @@ -79,7 +80,11 @@ def __init__(
self.llm_model,
)

def _build_messages(self, prompt: str | list[str] | None = None) -> list[dict]:
def _build_messages(
self,
prompt: str | list[str] | None = None,
system_prompt: str | None = None,
) -> list[dict]:
"""
Format the prompt messages for the LLM of the form : {"role": ..., "content": ...}

Expand All @@ -92,7 +97,10 @@ def _build_messages(self, prompt: str | list[str] | None = None) -> list[dict]:
messages = []

# Always include a system message. Default to empty string if no system prompt to support Ollama
system_content = self.system_prompt if self.system_prompt else ""
system_content = (
system_prompt if system_prompt is not None else self.system_prompt
)
system_content = system_content if system_content else ""
messages.append({"role": "system", "content": system_content})

if prompt:
Expand All @@ -104,6 +112,23 @@ def _build_messages(self, prompt: str | list[str] | None = None) -> list[dict]:

return messages

def parse_structured_output(
    self,
    response,
    response_model: type[BaseModel],
) -> BaseModel:
    """Normalize structured LLM output into the requested pydantic model.

    Accepts the three shapes a litellm-style response can take:
    an already-parsed model instance on ``message.parsed``, a plain
    mapping on ``message.parsed``, or raw JSON text in ``message.content``.

    Args:
        response: Completion response object exposing ``choices[0].message``.
        response_model: The pydantic model class to coerce the payload into.

    Returns:
        An instance of ``response_model``.
    """
    message = response.choices[0].message
    candidate = getattr(message, "parsed", None)

    # No pre-parsed payload at all: fall back to validating the raw JSON body.
    if candidate is None:
        return response_model.model_validate_json(message.content)

    # Provider already produced the target model — hand it back untouched.
    if isinstance(candidate, response_model):
        return candidate

    # Pre-parsed but not our model (e.g. a dict): validate field-wise.
    return response_model.model_validate(candidate)

@retry(
wait=wait_exponential(multiplier=1, min=1, max=60),
retry=retry_if_exception_type(RETRYABLE_EXCEPTIONS),
Expand All @@ -115,6 +140,7 @@ def generate(
tool_schema: list[dict] | None = None,
tool_choice: str = "auto",
response_format: dict | object | None = None,
system_prompt: str | None = None,
) -> str:
"""
Generate a response from the LLM using litellm based on the prompt
Expand All @@ -129,7 +155,7 @@ def generate(
The response from the LLM
"""

messages = self._build_messages(prompt)
messages = self._build_messages(prompt, system_prompt=system_prompt)

completion_kwargs = {
"model": self.llm_model,
Expand All @@ -151,11 +177,12 @@ async def agenerate(
tool_schema: list[dict] | None = None,
tool_choice: str = "auto",
response_format: dict | object | None = None,
system_prompt: str | None = None,
) -> str:
"""
Asynchronous version of generate() method for parallel LLM calls.
"""
messages = self._build_messages(prompt)
messages = self._build_messages(prompt, system_prompt=system_prompt)
async for attempt in AsyncRetrying(
wait=wait_exponential(multiplier=1, min=1, max=60),
retry=retry_if_exception_type(RETRYABLE_EXCEPTIONS),
Expand Down
192 changes: 192 additions & 0 deletions mesa_llm/reasoning/decision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from typing import TYPE_CHECKING

from pydantic import BaseModel, Field

from mesa_llm.reasoning.reasoning import Observation, Plan, Reasoning

if TYPE_CHECKING:
from mesa_llm.llm_agent import LLMAgent


class DecisionOption(BaseModel):
    """A single candidate action weighed during structured decision-making.

    Attributes:
        name: Short identifier for the option; ``DecisionOutput.chosen_option``
            must match one of these exactly.
        description: What taking this option entails.
        tradeoffs: Pros/cons of this option relative to the alternatives.
        score: Relative evaluation score, constrained to [0.0, 1.0] so
            options are comparable within a single decision.
    """

    name: str
    description: str
    tradeoffs: list[str]
    score: float = Field(
        ge=0.0,
        le=1.0,
        description="Relative evaluation score for this option in the current context.",
    )


class DecisionOutput(BaseModel):
    """Strict schema for the structured decision artifact returned by the LLM.

    Used as ``response_format`` so the model must emit a machine-checkable
    JSON object rather than free-form prose. Field meanings mirror the
    decision system prompt in ``DecisionReasoning.get_decision_system_prompt``.

    Attributes:
        goal: The agent's current primary objective.
        constraints: Rules/limits restricting the agent's actions.
        known_facts: Facts grounded in observation or memory.
        unknowns: Missing information relevant to the decision.
        assumptions: Inferences bridging known facts and unknowns.
        options: Candidate actions considered (see DecisionOption).
        chosen_option: Exact name of the selected option from ``options``.
        rationale: Justification for the chosen option.
        confidence: Certainty in the decision, constrained to [0.0, 1.0].
        risks: Potential failure modes of the chosen option.
        next_action: Single executable command passed on to tool execution.
    """

    goal: str
    constraints: list[str]
    known_facts: list[str]
    unknowns: list[str]
    assumptions: list[str]
    options: list[DecisionOption]
    chosen_option: str
    rationale: str
    confidence: float = Field(ge=0.0, le=1.0)
    risks: list[str]
    next_action: str


class DecisionReasoning(Reasoning):
    """
    Structured decision-making reasoning that returns a strict JSON object before
    converting the selected next action into tool calls.

    The LLM call is constrained with ``response_format=DecisionOutput`` so the
    model must produce a validated decision artifact; the artifact is stored in
    the agent's memory and its ``next_action`` field is executed as a tool call.
    ``plan()`` and ``aplan()`` share all prompt assembly and post-processing via
    private helpers so the sync and async paths cannot drift apart.
    """

    def __init__(self, agent: "LLMAgent"):
        super().__init__(agent=agent)

    def get_decision_system_prompt(self) -> str:
        """Return the system prompt that forces the DecisionOutput JSON schema."""
        return """
        You are an autonomous agent operating within a simulation environment.

        Your task is to analyze your current observation and memory to make a highly structured, optimal decision.
        Do not produce free-form chain-of-thought prose. You must evaluate the situation and return a strict JSON object matching the required schema.

        Your response must include:
        - goal: Your current primary objective within the simulation.
        - constraints: Any rules, resource limits, or environmental boundaries restricting your actions.
        - known_facts: Verified data strictly grounded in your current observation or historical memory.
        - unknowns: Critical missing information required for perfect decision-making.
        - assumptions: Logical inferences made to bridge the gap between known facts and unknowns.
        - options: A list of distinct, executable choices currently available to you. Each must include a name, description, tradeoffs, and a relative evaluation score.
        - chosen_option: The exact name of the best option selected from the list above.
        - rationale: A concise, logical justification for why this option was chosen over the alternatives.
        - confidence: A float between 0.0 and 1.0 representing your certainty in this decision.
        - risks: Potential negative outcomes or failure states associated with the chosen option.
        - next_action: A single, concrete, and strictly formatted executable command.

        Execution Requirements:
        1. Ground all known_facts entirely in the provided observation context. Do not hallucinate simulation state or capabilities.
        2. next_action must strictly match an available execution command. Do not invent tools.
        3. If information is heavily constrained or missing, explicitly reflect this by lowering the confidence score and detailing the danger in risks.
        """

    def get_decision_prompt(self, obs: Observation) -> list[str]:
        """Assemble the user prompt segments: memory, last communication, observation.

        Memory accessors are looked up defensively with ``getattr`` so memory
        implementations lacking these methods still work.
        """
        prompt_list = []

        get_prompt_ready = getattr(self.agent.memory, "get_prompt_ready", None)
        if callable(get_prompt_ready):
            prompt_list.append(get_prompt_ready())

        get_communication_history = getattr(
            self.agent.memory, "get_communication_history", None
        )
        last_communication = (
            get_communication_history() if callable(get_communication_history) else ""
        )

        if last_communication:
            prompt_list.append("last communication: \n" + str(last_communication))
        if obs:
            prompt_list.append("current observation: \n" + str(obs))

        return prompt_list

    def _resolve_prompt_list(
        self, prompt: str | None, obs: Observation | None
    ) -> list[str]:
        """Build the final prompt list shared by plan() and aplan().

        Appends the explicit ``prompt`` if given, otherwise falls back to
        ``agent.step_prompt``.

        Raises:
            ValueError: If neither an explicit prompt nor ``agent.step_prompt``
                is available.
        """
        prompt_list = self.get_decision_prompt(obs)
        if prompt is not None:
            prompt_list.append(prompt)
        elif self.agent.step_prompt is not None:
            prompt_list.append(self.agent.step_prompt)
        else:
            raise ValueError("No prompt provided and agent.step_prompt is None.")
        return prompt_list

    def _decision_request_kwargs(
        self, prompt_list: list[str], selected_tools: list[str] | None
    ) -> dict:
        """Common keyword arguments for the structured-decision LLM call."""
        return {
            "prompt": prompt_list,
            "tool_schema": self.agent.tool_manager.get_all_tools_schema(
                selected_tools
            ),
            # Tools are advertised for context only; the decision itself is
            # returned as JSON and executed afterwards via (a)execute_tool_call.
            "tool_choice": "none",
            "response_format": DecisionOutput,
            "system_prompt": self.get_decision_system_prompt(),
        }

    def _format_decision(self, rsp) -> dict:
        """Normalize the raw LLM response into a plain dict of DecisionOutput fields."""
        return self.agent.llm.parse_structured_output(rsp, DecisionOutput).model_dump()

    def _publish_plan_display(self, decision: dict) -> None:
        """Surface the decision rationale in the step display, when the agent has one."""
        if hasattr(self.agent, "_step_display_data"):
            self.agent._step_display_data["plan_content"] = decision["rationale"]

    def plan(
        self,
        prompt: str | None = None,
        obs: Observation | None = None,
        ttl: int = 1,
        selected_tools: list[str] | None = None,
    ) -> Plan:
        """
        Plan the next action through a structured decision artifact.

        Args:
            prompt: Explicit prompt; falls back to ``agent.step_prompt``.
            obs: Current observation; generated from the agent if omitted.
            ttl: Time-to-live forwarded to the resulting plan's tool call.
            selected_tools: Restrict the advertised tool schema to these names.

        Returns:
            The Plan produced by executing the decision's ``next_action``.
        """
        if obs is None:
            obs = self.agent.generate_obs()

        prompt_list = self._resolve_prompt_list(prompt, obs)

        rsp = self.agent.llm.generate(
            **self._decision_request_kwargs(prompt_list, selected_tools)
        )

        formatted_response = self._format_decision(rsp)
        self.agent.memory.add_to_memory(type="decision", content=formatted_response)
        self._publish_plan_display(formatted_response)

        return self.execute_tool_call(
            formatted_response["next_action"],
            selected_tools=selected_tools,
            ttl=ttl,
        )

    async def aplan(
        self,
        prompt: str | None = None,
        obs: Observation | None = None,
        ttl: int = 1,
        selected_tools: list[str] | None = None,
    ) -> Plan:
        """
        Asynchronous version of plan() method for parallel planning.

        Mirrors plan() exactly, using the async agent/memory/LLM entry points.
        """
        if obs is None:
            obs = await self.agent.agenerate_obs()

        prompt_list = self._resolve_prompt_list(prompt, obs)

        rsp = await self.agent.llm.agenerate(
            **self._decision_request_kwargs(prompt_list, selected_tools)
        )

        formatted_response = self._format_decision(rsp)
        await self.agent.memory.aadd_to_memory(
            type="decision", content=formatted_response
        )
        self._publish_plan_display(formatted_response)

        return await self.aexecute_tool_call(
            formatted_response["next_action"],
            selected_tools=selected_tools,
            ttl=ttl,
        )
33 changes: 32 additions & 1 deletion tests/test_module_llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
from pydantic import BaseModel

from mesa_llm.module_llm import ModuleLLM

Expand Down Expand Up @@ -82,6 +83,36 @@ def test_build_messages(self):
messages = llm._build_messages(prompt=None)
assert messages == [{"role": "system", "content": ""}]

messages = llm._build_messages(
"Hello, how are you?", system_prompt="Per-call prompt"
)
assert messages == [
{"role": "system", "content": "Per-call prompt"},
{"role": "user", "content": "Hello, how are you?"},
]

def test_parse_structured_output(self):
    """parse_structured_output handles model, dict, and raw-JSON payloads."""

    class DummyOutput(BaseModel):
        answer: str

    llm = ModuleLLM(llm_model="openai/gpt-4o")

    message = Mock()
    response = Mock()
    response.choices = [Mock(message=message)]

    # Case 1: provider already returned the target model instance.
    message.parsed = DummyOutput(answer="parsed")
    assert llm.parse_structured_output(response, DummyOutput).answer == "parsed"

    # Case 2: pre-parsed payload is a plain dict — validated field-wise.
    message.parsed = {"answer": "dict"}
    assert llm.parse_structured_output(response, DummyOutput).answer == "dict"

    # Case 3: nothing pre-parsed — raw JSON content is the fallback.
    message.parsed = None
    message.content = '{"answer":"json"}'
    assert llm.parse_structured_output(response, DummyOutput).answer == "json"

def test_generate(self, monkeypatch, llm_response_factory):
monkeypatch.setattr(
"mesa_llm.module_llm.completion", lambda **kwargs: llm_response_factory()
Expand Down
Loading