Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions mesa_llm/reasoning/rewoo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from typing import TYPE_CHECKING

from mesa_llm.reasoning.reasoning import (
Expand Down Expand Up @@ -30,6 +31,7 @@ def __init__(self, agent: "LLMAgent"):
self.remaining_tool_calls = 0 # Initialize remaining tool calls
self.current_plan: Plan | None = None
self.current_obs: Observation | None = None
self._all_tool_calls = [] # Store original tool calls to avoid mutation during replay

def get_rewoo_system_prompt(self, obs: Observation) -> str:
memory = getattr(self.agent, "memory", None)
Expand Down Expand Up @@ -110,14 +112,16 @@ def plan(
"""
# If we have remaining tool calls, skip observation and plan generation
if self.remaining_tool_calls > 0:
index_of_tool = (
len(self.current_plan.tool_calls) - self.remaining_tool_calls
)
# _all_tool_calls is the single source of truth; fall back to current_plan
# if state was set externally without going through plan generation
all_calls = self._all_tool_calls or self.current_plan.tool_calls
index_of_tool = len(all_calls) - self.remaining_tool_calls
self.remaining_tool_calls -= 1
tool_call = [self.current_plan.tool_calls[index_of_tool]]
current_plan = self.current_plan
current_plan.tool_calls = tool_call
return Plan(llm_plan=current_plan, step=self.current_obs.step, ttl=ttl)
tool_call = [all_calls[index_of_tool]]
# Return a plan with only the required tool call, without mutating current_plan
temp_plan = copy.copy(self.current_plan)
temp_plan.tool_calls = tool_call
return Plan(llm_plan=temp_plan, step=self.current_obs.step, ttl=ttl)

# If no prompt is provided, use the agent's default step prompt
if prompt is None:
Expand Down Expand Up @@ -152,8 +156,11 @@ def plan(
# Count the number of tool calls in the response and set remaining_tool_calls
if hasattr(rewoo_plan.llm_plan, "tool_calls"):
self.remaining_tool_calls = len(rewoo_plan.llm_plan.tool_calls)
# Store a copy of the original tool calls to avoid mutation during replay
self._all_tool_calls = list(rewoo_plan.llm_plan.tool_calls)
else:
self.remaining_tool_calls = 0
self._all_tool_calls = []
self.current_plan = rewoo_plan.llm_plan

return rewoo_plan
Expand All @@ -170,14 +177,16 @@ async def aplan(
"""
# If we have remaining tool calls, skip observation and plan generation
if self.remaining_tool_calls > 0:
index_of_tool = (
len(self.current_plan.tool_calls) - self.remaining_tool_calls
)
# _all_tool_calls is the single source of truth; fall back to current_plan
# if state was set externally without going through plan generation
all_calls = self._all_tool_calls or self.current_plan.tool_calls
index_of_tool = len(all_calls) - self.remaining_tool_calls
self.remaining_tool_calls -= 1
tool_call = [self.current_plan.tool_calls[index_of_tool]]
current_plan = self.current_plan
current_plan.tool_calls = tool_call
return Plan(llm_plan=current_plan, step=self.current_obs.step, ttl=ttl)
tool_call = [all_calls[index_of_tool]]
# Return a plan with only the required tool call, without mutating current_plan
temp_plan = copy.copy(self.current_plan)
temp_plan.tool_calls = tool_call
return Plan(llm_plan=temp_plan, step=self.current_obs.step, ttl=ttl)

# If no prompt is provided, use the agent's default step prompt
if prompt is None:
Expand Down Expand Up @@ -212,8 +221,11 @@ async def aplan(
# Count the number of tool calls in the response and set remaining_tool_calls
if hasattr(rewoo_plan.llm_plan, "tool_calls"):
self.remaining_tool_calls = len(rewoo_plan.llm_plan.tool_calls)
# Store a copy of the original tool calls to avoid mutation during replay
self._all_tool_calls = list(rewoo_plan.llm_plan.tool_calls)
else:
self.remaining_tool_calls = 0
self._all_tool_calls = []
self.current_plan = rewoo_plan.llm_plan

return rewoo_plan