Skip to content

Commit 98e0e87

Browse files
committed
Merge branch 'main' into fix/episodic-memory
2 parents ac90308 + 4ea91ed commit 98e0e87

28 files changed

+2190
-184
lines changed

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
run: python -m build
3838
- name: Upload package as artifact to GitHub
3939
if: github.repository == 'mesa/mesa-llm' && startsWith(github.ref, 'refs/tags')
40-
uses: actions/upload-artifact@v6
40+
uses: actions/upload-artifact@v7
4141
with:
4242
name: package
4343
path: dist/

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ ci:
55
repos:
66
- repo: https://github.com/astral-sh/ruff-pre-commit
77
# Ruff version.
8-
rev: v0.14.14
8+
rev: v0.15.4
99
hooks:
1010
# Run the linter with fix argument.
1111
- id: ruff-check

mesa_llm/memory/lt_memory.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,16 @@ def _update_long_term_memory(self):
6363
Update the long term memory by summarizing the short term memory with a LLM
6464
"""
6565
prompt = self._build_consolidation_prompt()
66-
self.long_term_memory = self.llm.generate(prompt)
66+
response = self.llm.generate(prompt)
67+
self.long_term_memory = response.choices[0].message.content
6768

6869
async def _aupdate_long_term_memory(self):
6970
"""
7071
Asynchronous version of _update_long_term_memory
7172
"""
7273
prompt = self._build_consolidation_prompt()
73-
self.long_term_memory = await self.llm.agenerate(prompt)
74+
response = await self.llm.agenerate(prompt)
75+
self.long_term_memory = response.choices[0].message.content
7476

7577
def process_step(self, pre_step: bool = False):
7678
"""

mesa_llm/memory/memory.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,12 @@ def add_to_memory(self, type: str, content: dict):
137137
"""
138138
Add a new entry to the memory
139139
"""
140+
if not isinstance(content, dict):
141+
raise TypeError(
142+
"Expected 'content' to be dict, "
143+
f"got {content.__class__.__name__}: {content!r}"
144+
)
145+
140146
if type == "observation":
141147
# Only store changed parts of observation
142148
changed_parts = {

mesa_llm/memory/st_lt_memory.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,16 @@ def _update_long_term_memory(self):
8888
Update the long term memory by summarizing the short term memory with a LLM
8989
"""
9090
prompt = self._build_consolidation_prompt()
91-
self.long_term_memory = self.llm.generate(prompt)
91+
response = self.llm.generate(prompt)
92+
self.long_term_memory = response.choices[0].message.content
9293

9394
async def _aupdate_long_term_memory(self):
9495
"""
9596
Async Function to update long term memory
9697
"""
9798
prompt = self._build_consolidation_prompt()
98-
self.long_term_memory = await self.llm.agenerate(prompt)
99+
response = await self.llm.agenerate(prompt)
100+
self.long_term_memory = response.choices[0].message.content
99101

100102
def _process_step_core(self, pre_step: bool):
101103
"""
@@ -191,10 +193,10 @@ def format_short_term(self) -> str:
191193
return "\n".join(lines)
192194

193195
def get_prompt_ready(self) -> str:
194-
return [
195-
f"Short term memory:\n {self.format_short_term()}",
196-
f"Long term memory: \n{self.format_long_term()}",
197-
]
196+
return (
197+
f"Short term memory:\n {self.format_short_term()}\n\n"
198+
f"Long term memory: \n{self.format_long_term()}"
199+
)
198200

199201
def get_communication_history(self) -> str:
200202
"""

mesa_llm/memory/st_memory.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class ShortTermMemory(Memory):
1313
1414
Attributes:
1515
agent : the agent that the memory belongs to
16-
n : number of short-term memories to remember
16+
n : positive number of short-term memories to remember
1717
display : whether to display the memory
1818
llm_model : the model to use for the summarization
1919
"""
@@ -24,12 +24,16 @@ def __init__(
2424
n: int = 5,
2525
display: bool = True,
2626
):
27+
if n < 1:
28+
raise ValueError("n must be >= 1 for ShortTermMemory")
29+
2730
super().__init__(
2831
agent=agent,
2932
display=display,
3033
)
3134
self.n = n
32-
self.short_term_memory = deque()
35+
self.short_term_memory = deque(maxlen=self.n)
36+
self._current_step_entry: MemoryEntry | None = None
3337

3438
async def aprocess_step(self, pre_step: bool = False):
3539
"""
@@ -40,35 +44,36 @@ async def aprocess_step(self, pre_step: bool = False):
4044
def process_step(self, pre_step: bool = False):
4145
"""
4246
Process the step of the agent :
43-
- Add the new entry to the short term memory
47+
- Capture pre-step content into the current in-progress step entry
48+
- Merge current and post-step content into one finalized entry
4449
- Display the new entry
4550
"""
4651

47-
# Add the new entry to the short term memory
52+
# Save a temporary pre-step snapshot. This entry is not persisted in deque.
4853
if pre_step:
49-
new_entry = MemoryEntry(
54+
self._current_step_entry = MemoryEntry(
5055
agent=self.agent,
5156
content=self.step_content,
5257
step=None,
5358
)
54-
self.short_term_memory.append(new_entry)
5559
self.step_content = {}
5660
return
5761

58-
elif not self.short_term_memory[-1].content.get("step", None):
59-
pre_step = self.short_term_memory.pop()
60-
self.step_content.update(pre_step.content)
62+
new_entry = None
63+
if self._current_step_entry is not None:
64+
merged_content = dict(self.step_content)
65+
merged_content.update(self._current_step_entry.content)
6166
new_entry = MemoryEntry(
6267
agent=self.agent,
63-
content=self.step_content,
68+
content=merged_content,
6469
step=self.agent.model.steps,
6570
)
66-
6771
self.short_term_memory.append(new_entry)
72+
self._current_step_entry = None
6873
self.step_content = {}
6974

7075
# Display the new entry
71-
if self.display:
76+
if self.display and new_entry is not None:
7277
new_entry.display()
7378

7479
def format_short_term(self) -> str:

mesa_llm/module_llm.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,25 @@ def __init__(
3737
Initialize the LLM module
3838
3939
Args:
40-
llm_model: The model to use for the LLM in the format of {provider}/{LLM}
40+
llm_model: The model to use for the LLM in the format
41+
"{provider}/{model}" (for example, "openai/gpt-4o").
4142
api_base: The API base to use if the LLM provider is Ollama
4243
system_prompt: The system prompt to use for the LLM
44+
45+
Raises:
46+
ValueError: If llm_model is not in the expected "{provider}/{model}"
47+
format, or if the provider API key is missing.
4348
"""
4449
self.api_base = api_base
4550
self.llm_model = llm_model
4651
self.system_prompt = system_prompt
52+
53+
if "/" not in llm_model:
54+
raise ValueError(
55+
f"Invalid model format '{llm_model}'. "
56+
"Expected '{provider}/{model}', e.g. 'openai/gpt-4o'."
57+
)
58+
4759
provider = self.llm_model.split("/")[0].upper()
4860

4961
if provider in ["OLLAMA", "OLLAMA_CHAT"]:

mesa_llm/reasoning/cot.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ class CoTReasoning(Reasoning):
1414
- **agent** (LLMAgent reference)
1515
1616
Methods:
17-
- **plan(prompt, obs=None, ttl=1, selected_tools=None)** → *Plan* - Generate synchronous plan with CoT reasoning
18-
- **async aplan(prompt, obs=None, ttl=1, selected_tools=None)** → *Plan* - Generate asynchronous plan with CoT reasoning
17+
- **plan(obs, ttl=1, prompt=None, selected_tools=None)** → *Plan* - Generate synchronous plan with CoT reasoning
18+
- **async aplan(obs, ttl=1, prompt=None, selected_tools=None)** → *Plan* - Generate asynchronous plan with CoT reasoning
1919
2020
Reasoning Format:
2121
Thought 1: [Initial reasoning based on observation]
@@ -87,9 +87,9 @@ def get_cot_system_prompt(self, obs: Observation) -> str:
8787

8888
def plan(
8989
self,
90-
obs: Observation,
91-
ttl: int = 1,
9290
prompt: str | None = None,
91+
obs: Observation | None = None,
92+
ttl: int = 1,
9393
selected_tools: list[str] | None = None,
9494
) -> Plan:
9595
"""
@@ -102,12 +102,17 @@ def plan(
102102
else:
103103
raise ValueError("No prompt provided and agent.step_prompt is None.")
104104

105+
if obs is None:
106+
obs = self.agent.generate_obs()
107+
105108
step = obs.step + 1
106109
llm = self.agent.llm
107110
obs_str = str(obs)
108111

109112
# Add current observation to memory (for record)
110-
self.agent.memory.add_to_memory(type="Observation", content=obs_str)
113+
self.agent.memory.add_to_memory(
114+
type="Observation", content={"content": obs_str}
115+
)
111116
system_prompt = self.get_cot_system_prompt(obs)
112117

113118
llm.system_prompt = system_prompt
@@ -118,7 +123,9 @@ def plan(
118123
)
119124

120125
chaining_message = rsp.choices[0].message.content
121-
self.agent.memory.add_to_memory(type="Plan", content=chaining_message)
126+
self.agent.memory.add_to_memory(
127+
type="Plan", content={"content": chaining_message}
128+
)
122129

123130
# Pass plan content to agent for display
124131
if hasattr(self.agent, "_step_display_data"):
@@ -131,25 +138,41 @@ def plan(
131138
tool_choice="required",
132139
)
133140
response_message = rsp.choices[0].message
134-
cot_plan = Plan(step=step, llm_plan=response_message, ttl=1)
141+
cot_plan = Plan(step=step, llm_plan=response_message, ttl=ttl)
135142

136-
self.agent.memory.add_to_memory(type="Plan-Execution", content=str(cot_plan))
143+
self.agent.memory.add_to_memory(
144+
type="Plan-Execution", content={"content": str(cot_plan)}
145+
)
137146

138147
return cot_plan
139148

140149
async def aplan(
141150
self,
142-
prompt: str,
143-
obs: Observation,
151+
prompt: str | None = None,
152+
obs: Observation | None = None,
144153
ttl: int = 1,
145154
selected_tools: list[str] | None = None,
146155
) -> Plan:
147156
"""
148157
Asynchronous version of plan() method for parallel planning.
149158
"""
159+
# If no prompt is provided, use the agent's default step prompt
160+
if prompt is None:
161+
if self.agent.step_prompt is not None:
162+
prompt = self.agent.step_prompt
163+
else:
164+
raise ValueError("No prompt provided and agent.step_prompt is None.")
165+
166+
if obs is None:
167+
obs = await self.agent.agenerate_obs()
168+
150169
step = obs.step + 1
151170
llm = self.agent.llm
152171

172+
obs_str = str(obs)
173+
await self.agent.memory.aadd_to_memory(
174+
type="Observation", content={"content": obs_str}
175+
)
153176
system_prompt = self.get_cot_system_prompt(obs)
154177
llm.system_prompt = system_prompt
155178

@@ -160,7 +183,9 @@ async def aplan(
160183
)
161184

162185
chaining_message = rsp.choices[0].message.content
163-
await self.agent.memory.aadd_to_memory(type="Plan", content=chaining_message)
186+
await self.agent.memory.aadd_to_memory(
187+
type="Plan", content={"content": chaining_message}
188+
)
164189

165190
# Pass plan content to agent for display
166191
if hasattr(self.agent, "_step_display_data"):
@@ -173,10 +198,10 @@ async def aplan(
173198
tool_choice="required",
174199
)
175200
response_message = rsp.choices[0].message
176-
cot_plan = Plan(step=step, llm_plan=response_message, ttl=1)
201+
cot_plan = Plan(step=step, llm_plan=response_message, ttl=ttl)
177202

178203
await self.agent.memory.aadd_to_memory(
179-
type="Plan-Execution", content=str(cot_plan)
204+
type="Plan-Execution", content={"content": str(cot_plan)}
180205
)
181206

182207
return cot_plan

0 commit comments

Comments (0)