Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 53 additions & 27 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,52 @@ def system_prompt(self, value: str):
"""The 'system_prompt' property is read-only. Use 'self.prompt_templates["system_prompt"]' instead."""
)

def _populate_planning_initial_plan_template(self, task: str, step: int) -> str:
return populate_template(
self.prompt_templates["planning"]["initial_plan"],
variables={"task": task, "tools": self.tools, "managed_agents": self.managed_agents},
)

def _populate_planning_update_plan_pre_messages_template(self, task: str, step: int) -> str:
return populate_template(
self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task}
)

def _populate_planning_update_plan_post_messages_template(self, task: str, step: int) -> str:
return populate_template(
self.prompt_templates["planning"]["update_plan_post_messages"],
variables={
"task": task,
"tools": self.tools,
"managed_agents": self.managed_agents,
"remaining_steps": (self.max_steps - step),
},
)

def _populate_final_answer_pre_messages_template(self, task: str) -> str:
return populate_template(
self.prompt_templates["final_answer"]["pre_messages"],
variables={"task": task},
)

def _populate_final_answer_post_messages_template(self, task: str) -> str:
return populate_template(
self.prompt_templates["final_answer"]["post_messages"],
variables={"task": task},
)

def _populate_managed_agent_task_template(self, agent_name: str, task: str) -> str:
return populate_template(
self.prompt_templates["managed_agent"]["task"],
variables={"name": agent_name, "task": task},
)

def _populate_managed_agent_report_template(self, agent_name: str, answer: str) -> str:
return populate_template(
self.prompt_templates["managed_agent"]["report"],
variables={"name": agent_name, "final_answer": answer},
)

def _validate_name(self, name: str | None) -> str | None:
if name is not None and not is_valid_name(name):
raise ValueError(f"Agent name '{name}' must be a valid Python identifier and not a reserved keyword.")
Expand Down Expand Up @@ -647,10 +693,7 @@ def _generate_planning_step(
content=[
{
"type": "text",
"text": populate_template(
self.prompt_templates["planning"]["initial_plan"],
variables={"task": task, "tools": self.tools, "managed_agents": self.managed_agents},
),
"text": self._populate_planning_initial_plan_template(task, step),
}
],
)
Expand Down Expand Up @@ -687,9 +730,7 @@ def _generate_planning_step(
content=[
{
"type": "text",
"text": populate_template(
self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task}
),
"text": self._populate_planning_update_plan_pre_messages_template(task, step),
}
],
)
Expand All @@ -698,15 +739,7 @@ def _generate_planning_step(
content=[
{
"type": "text",
"text": populate_template(
self.prompt_templates["planning"]["update_plan_post_messages"],
variables={
"task": task,
"tools": self.tools,
"managed_agents": self.managed_agents,
"remaining_steps": (self.max_steps - step),
},
),
"text": self._populate_planning_update_plan_post_messages_template(task, step),
}
],
)
Expand Down Expand Up @@ -824,7 +857,7 @@ def provide_final_answer(self, task: str) -> ChatMessage:
content=[
{
"type": "text",
"text": self.prompt_templates["final_answer"]["pre_messages"],
"text": self._populate_final_answer_pre_messages_template(task),
}
],
)
Expand All @@ -836,9 +869,7 @@ def provide_final_answer(self, task: str) -> ChatMessage:
content=[
{
"type": "text",
"text": populate_template(
self.prompt_templates["final_answer"]["post_messages"], variables={"task": task}
),
"text": self._populate_final_answer_post_messages_template(task),
}
],
)
Expand Down Expand Up @@ -869,18 +900,13 @@ def __call__(self, task: str, **kwargs):
"""Adds additional prompting for the managed agent, runs it, and wraps the output.
This method is called only by a managed agent.
"""
full_task = populate_template(
self.prompt_templates["managed_agent"]["task"],
variables=dict(name=self.name, task=task),
)
full_task = self._populate_managed_agent_task_template(self.name, task)
result = self.run(full_task, **kwargs)
if isinstance(result, RunResult):
report = result.output
else:
report = result
answer = populate_template(
self.prompt_templates["managed_agent"]["report"], variables=dict(name=self.name, final_answer=report)
)
answer = self._populate_managed_agent_report_template(self.name, report)
if self.provide_run_summary:
answer += "\n\nFor more detail, find below a summary of this agent's work:\n<summary_of_work>\n"
for message in self.write_memory_to_messages(summary_mode=True):
Expand Down