Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions mesa_llm/memory/episodic_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING

from pydantic import BaseModel
from terminal_style import style

from mesa_llm.memory.memory import Memory, MemoryEntry

Expand Down Expand Up @@ -152,7 +153,14 @@ def grade_event_importance(self, type: str, content: dict) -> float:
response_format=EventGrade,
)

formatted_response = json.loads(rsp.choices[0].message.content)
# Parse JSON arguments safely
try:
formatted_response = json.loads(rsp.choices[0].message.content)
except json.JSONDecodeError as e:
raise ValueError(
style(f"Invalid JSON response returned by the model: {e}", color="red")
) from e

return formatted_response["grade"]

async def agrade_event_importance(self, type: str, content: dict) -> float:
Expand All @@ -167,7 +175,14 @@ async def agrade_event_importance(self, type: str, content: dict) -> float:
response_format=EventGrade,
)

formatted_response = json.loads(rsp.choices[0].message.content)
# Parse JSON arguments safely
try:
formatted_response = json.loads(rsp.choices[0].message.content)
except json.JSONDecodeError as e:
raise ValueError(
style(f"Invalid JSON response returned by the model: {e}", color="red")
) from e

return formatted_response["grade"]

def retrieve_top_k_entries(self, k: int) -> list[MemoryEntry]:
Expand Down
17 changes: 15 additions & 2 deletions mesa_llm/reasoning/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import TYPE_CHECKING

from pydantic import BaseModel
from terminal_style import style

from mesa_llm.reasoning.reasoning import Observation, Plan, Reasoning

Expand Down Expand Up @@ -94,7 +95,13 @@ def plan(
response_format=ReActOutput,
)

formatted_response = json.loads(rsp.choices[0].message.content)
# Parse JSON arguments safely
try:
formatted_response = json.loads(rsp.choices[0].message.content)
except json.JSONDecodeError as e:
raise ValueError(
style(f"Invalid JSON response returned by the model: {e}", color="red")
) from e

self.agent.memory.add_to_memory(type="plan", content=formatted_response)

Expand Down Expand Up @@ -145,7 +152,13 @@ async def aplan(
response_format=ReActOutput,
)

formatted_response = json.loads(rsp.choices[0].message.content)
# Parse JSON arguments safely
try:
formatted_response = json.loads(rsp.choices[0].message.content)
except json.JSONDecodeError as e:
raise ValueError(
style(f"Invalid JSON response returned by the model: {e}", color="red")
) from e

await self.agent.memory.aadd_to_memory(type="plan", content=formatted_response)

Expand Down
12 changes: 11 additions & 1 deletion mesa_llm/recording/agent_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from rich.prompt import Prompt
from rich.table import Table

# Aliased as we have a 'style' variable used in this file
from terminal_style import style as terminal_style


class AgentViewer:
"""
Expand All @@ -40,7 +43,14 @@ def _load_recording(self):
return pickle.load(f) # noqa: S301
else:
with open(self.recording_path) as f:
return json.load(f)
try:
return json.load(f)
except json.JSONDecodeError as e:
raise ValueError(
terminal_style(
f"Invalid JSON object from the file: {e}", color="red"
)
) from e

def _organize_events_by_agent(self):
"""Organize events by agent ID."""
Expand Down