From 79261097f6563d7d0fe79d05c96b6061acd0d547 Mon Sep 17 00:00:00 2001 From: gael-ft <24495136+gael-ft@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:47:33 +0200 Subject: [PATCH] Isolate template generation to override --- src/smolagents/agents.py | 80 ++++++++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 27 deletions(-) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index cdda9aca8..ebad28c3d 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -365,6 +365,52 @@ def system_prompt(self, value: str): """The 'system_prompt' property is read-only. Use 'self.prompt_templates["system_prompt"]' instead.""" ) + def _populate_planning_initial_plan_template(self, task: str, step: int) -> str: + return populate_template( + self.prompt_templates["planning"]["initial_plan"], + variables={"task": task, "tools": self.tools, "managed_agents": self.managed_agents}, + ) + + def _populate_planning_update_plan_pre_messages_template(self, task: str, step: int) -> str: + return populate_template( + self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task} + ) + + def _populate_planning_update_plan_post_messages_template(self, task: str, step: int) -> str: + return populate_template( + self.prompt_templates["planning"]["update_plan_post_messages"], + variables={ + "task": task, + "tools": self.tools, + "managed_agents": self.managed_agents, + "remaining_steps": (self.max_steps - step), + }, + ) + + def _populate_final_answer_pre_messages_template(self, task: str) -> str: + return populate_template( + self.prompt_templates["final_answer"]["pre_messages"], + variables={"task": task}, + ) + + def _populate_final_answer_post_messages_template(self, task: str) -> str: + return populate_template( + self.prompt_templates["final_answer"]["post_messages"], + variables={"task": task}, + ) + + def _populate_managed_agent_task_template(self, agent_name: str, task: str) -> str: + return populate_template( + self.prompt_templates["managed_agent"]["task"], + variables={"name": agent_name, "task": task}, + ) + + def _populate_managed_agent_report_template(self, agent_name: str, answer: str) -> str: + return populate_template( + self.prompt_templates["managed_agent"]["report"], + variables={"name": agent_name, "final_answer": answer}, + ) + def _validate_name(self, name: str | None) -> str | None: if name is not None and not is_valid_name(name): raise ValueError(f"Agent name '{name}' must be a valid Python identifier and not a reserved keyword.") @@ -647,10 +693,7 @@ def _generate_planning_step( content=[ { "type": "text", - "text": populate_template( - self.prompt_templates["planning"]["initial_plan"], - variables={"task": task, "tools": self.tools, "managed_agents": self.managed_agents}, - ), + "text": self._populate_planning_initial_plan_template(task, step), } ], ) @@ -687,9 +730,7 @@ def _generate_planning_step( content=[ { "type": "text", - "text": populate_template( - self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task} - ), + "text": self._populate_planning_update_plan_pre_messages_template(task, step), } ], ) @@ -698,15 +739,7 @@ def _generate_planning_step( content=[ { "type": "text", - "text": populate_template( - self.prompt_templates["planning"]["update_plan_post_messages"], - variables={ - "task": task, - "tools": self.tools, - "managed_agents": self.managed_agents, - "remaining_steps": (self.max_steps - step), - }, - ), + "text": self._populate_planning_update_plan_post_messages_template(task, step), } ], ) @@ -824,7 +857,7 @@ def provide_final_answer(self, task: str) -> ChatMessage: content=[ { "type": "text", - "text": self.prompt_templates["final_answer"]["pre_messages"], + "text": self._populate_final_answer_pre_messages_template(task), } ], ) @@ -836,9 +869,7 @@ def provide_final_answer(self, task: str) -> ChatMessage: content=[ { "type": "text", - "text": populate_template( - self.prompt_templates["final_answer"]["post_messages"], variables={"task": task} - ), + "text": self._populate_final_answer_post_messages_template(task), } ], ) @@ -869,18 +900,13 @@ def __call__(self, task: str, **kwargs): """Adds additional prompting for the managed agent, runs it, and wraps the output. This method is called only by a managed agent. """ - full_task = populate_template( - self.prompt_templates["managed_agent"]["task"], - variables=dict(name=self.name, task=task), - ) + full_task = self._populate_managed_agent_task_template(self.name, task) result = self.run(full_task, **kwargs) if isinstance(result, RunResult): report = result.output else: report = result - answer = populate_template( - self.prompt_templates["managed_agent"]["report"], variables=dict(name=self.name, final_answer=report) - ) + answer = self._populate_managed_agent_report_template(self.name, report) if self.provide_run_summary: answer += "\n\nFor more detail, find below a summary of this agent's work:\n\n" for message in self.write_memory_to_messages(summary_mode=True):