2020from typing import Any , Optional
2121from uuid import uuid4
2222
23- from verl .experimental .agent_loop .agent_loop import AgentLoopBase , AgentLoopOutput , register
23+ from transformers import AutoProcessor , AutoTokenizer
24+
25+ from verl .experimental .agent_loop .agent_loop import (
26+ AgentLoopBase ,
27+ AgentLoopOutput ,
28+ AsyncLLMServerManager ,
29+ DictConfigWrap ,
30+ register ,
31+ )
2432from verl .experimental .agent_loop .tool_parser import FunctionCall , ToolParser
2533from verl .experimental .agent_loop .utils import add_generation_prompt_for_gpt_oss , format_gpt_oss_tool_response_manually
2634from verl .interactions .base import BaseInteraction
@@ -44,7 +52,8 @@ class AgentState(Enum):
4452
4553
4654class AgentData :
47- """Encapsulates all state variables for the agent loop."""
55+ """Encapsulates all state variables for the agent loop. AgentData is passed to each tool call in case the
56+ tool needs to access the full history state. Users can store any tool session data in `extra_fields`."""
4857
4958 def __init__ (
5059 self ,
@@ -77,44 +86,49 @@ def __init__(
7786 # Temporary state for tool calls
7887 self .tool_calls : list [FunctionCall ] = []
7988
80- # Extra fields for dynamic addition
89+ # Extra fields for dynamic addition, e.g., tool session data
8190 self .extra_fields : dict [str , Any ] = {}
8291
8392
8493@register ("tool_agent" )
8594class ToolAgentLoop (AgentLoopBase ):
86- @classmethod
87- def init_class (cls , config , tokenizer , processor , ** kwargs ):
88- if cls ._class_initialized :
89- return
90- cls ._class_initialized = True
91- print ("Performing class-level ToolAgentLoop initialization" )
95+ def __init__ (
96+ self ,
97+ trainer_config : DictConfigWrap ,
98+ server_manager : AsyncLLMServerManager ,
99+ tokenizer : AutoTokenizer ,
100+ processor : AutoProcessor ,
101+ ** kwargs ,
102+ ):
103+ super ().__init__ (trainer_config , server_manager , tokenizer , processor , ** kwargs )
104+ config = trainer_config .config
92105
93106 # Initialize tools from config file
94- cls .tokenizer = tokenizer
95- cls .processor = processor
96- cls .max_user_turns = config .actor_rollout_ref .rollout .multi_turn .max_user_turns
97- cls .max_assistant_turns = config .actor_rollout_ref .rollout .multi_turn .max_assistant_turns
98- cls .max_parallel_calls = config .actor_rollout_ref .rollout .multi_turn .max_parallel_calls
99- cls .max_tool_response_length = config .actor_rollout_ref .rollout .multi_turn .max_tool_response_length
100- cls .tool_response_truncate_side = config .actor_rollout_ref .rollout .multi_turn .tool_response_truncate_side
107+ self .max_user_turns = config .actor_rollout_ref .rollout .multi_turn .max_user_turns
108+ self .max_assistant_turns = config .actor_rollout_ref .rollout .multi_turn .max_assistant_turns
109+ self .max_parallel_calls = config .actor_rollout_ref .rollout .multi_turn .max_parallel_calls
110+ self .max_tool_response_length = config .actor_rollout_ref .rollout .multi_turn .max_tool_response_length
111+ self .tool_response_truncate_side = config .actor_rollout_ref .rollout .multi_turn .tool_response_truncate_side
101112 tool_config_path = config .actor_rollout_ref .rollout .multi_turn .tool_config_path
102113 tool_list = initialize_tools_from_config (tool_config_path ) if tool_config_path else []
103- cls .tools = {tool .name : tool for tool in tool_list }
104- cls .tool_schemas = [tool .tool_schema .model_dump (exclude_unset = True , exclude_none = True ) for tool in tool_list ]
105- cls .tool_parser = ToolParser .get_tool_parser (config .actor_rollout_ref .rollout .multi_turn .format , cls .tokenizer )
106- cls .tool_parser_name = config .actor_rollout_ref .rollout .multi_turn .format
107- print (f"Initialized tools: { cls .tools } " )
114+ self .tools = {tool .name : tool for tool in tool_list }
115+ self .tool_schemas = [tool .tool_schema .model_dump (exclude_unset = True , exclude_none = True ) for tool in tool_list ]
116+ self .tool_parser = ToolParser .get_tool_parser (
117+ config .actor_rollout_ref .rollout .multi_turn .format , self .tokenizer
118+ )
119+ self .tool_parser_name = config .actor_rollout_ref .rollout .multi_turn .format
108120
109- cls .apply_chat_template_kwargs = config .data .get ("apply_chat_template_kwargs" , {})
110- cls .prompt_length = config .actor_rollout_ref .rollout .prompt_length
111- cls .response_length = config .actor_rollout_ref .rollout .response_length
112- cls .system_prompt = initialize_system_prompt (cls .tokenizer , ** cls .apply_chat_template_kwargs )
121+ self .apply_chat_template_kwargs = config .data .get ("apply_chat_template_kwargs" , {})
122+ self .prompt_length = config .actor_rollout_ref .rollout .prompt_length
123+ self .response_length = config .actor_rollout_ref .rollout .response_length
124+ self .system_prompt = initialize_system_prompt (self .tokenizer , ** self .apply_chat_template_kwargs )
113125
114126 # Initialize interactions from config file
115- cls .interaction_config_file = config .actor_rollout_ref .rollout .multi_turn .interaction_config_path
116- if cls .interaction_config_file :
117- cls .interaction_map : dict [str , BaseInteraction ] = cls ._initialize_interactions (cls .interaction_config_file )
127+ self .interaction_config_file = config .actor_rollout_ref .rollout .multi_turn .interaction_config_path
128+ if self .interaction_config_file :
129+ self .interaction_map : dict [str , BaseInteraction ] = self ._initialize_interactions (
130+ self .interaction_config_file
131+ )
118132
119133 @rollout_trace_op
120134 async def run (self , sampling_params : dict [str , Any ], ** kwargs ) -> AgentLoopOutput :
@@ -271,7 +285,7 @@ async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentSt
271285 tasks = []
272286 tool_call_names = []
273287 for tool_call in agent_data .tool_calls [: self .max_parallel_calls ]:
274- tasks .append (self ._call_tool (tool_call , agent_data .tools_kwargs ))
288+ tasks .append (self ._call_tool (tool_call , agent_data .tools_kwargs , agent_data ))
275289 tool_call_names .append (tool_call .name )
276290
277291 with simple_timer ("tool_calls" , agent_data .metrics ):
@@ -434,7 +448,7 @@ async def _handle_interacting_state(self, agent_data: AgentData) -> AgentState:
434448 return AgentState .GENERATING
435449
436450 async def _call_tool (
437- self , tool_call : FunctionCall , tools_kwargs : dict [str , Any ]
451+ self , tool_call : FunctionCall , tools_kwargs : dict [str , Any ], agent_data : AgentData
438452 ) -> tuple [ToolResponse , float , dict ]:
439453 """Call tool and return tool response."""
440454 tool , instance_id = None , None
@@ -445,7 +459,9 @@ async def _call_tool(
445459 tool = self .tools [tool_name ]
446460 kwargs = tools_kwargs .get (tool_name , {})
447461 instance_id , _ = await tool .create (create_kwargs = kwargs .get ("create_kwargs" , {}))
448- tool_execution_response , tool_reward , res = await tool .execute (instance_id , tool_args )
462+ tool_execution_response , tool_reward , res = await tool .execute (
463+ instance_id , tool_args , agent_data = agent_data
464+ )
449465 except Exception as e :
450466 logger .warning (f"Error when executing tool: { e } " )
451467 return (
@@ -481,8 +497,7 @@ async def _call_tool(
481497
482498 return ToolResponse (** tool_response_kwargs ), tool_reward , res
483499
484- @classmethod
485- def _initialize_interactions (cls , interaction_config_file ):
500+ def _initialize_interactions (self , interaction_config_file ):
486501 """Initialize interactions from configuration.
487502 Returns:
488503 dict[str, BaseInteraction]: A dictionary mapping interaction names to interaction instances.
@@ -491,5 +506,4 @@ def _initialize_interactions(cls, interaction_config_file):
491506 return {}
492507
493508 interaction_map = initialize_interactions_from_config (interaction_config_file )
494- logger .info (f"Initialize interactions from configuration: interaction_map: { list (interaction_map .keys ())} " )
495509 return interaction_map
0 commit comments