
Commit ebec85d

[rollout] feat: pass agent_data to tool calling (verl-project#4469)
1 parent baf3a63 commit ebec85d

4 files changed: +59 −72 lines changed


recipe/fully_async_policy/agent_loop/agent_loop.py

Lines changed: 2 additions & 2 deletions

@@ -27,8 +27,8 @@
     AgentLoopOutput,
     AgentLoopWorkerBase,
     AsyncLLMServerManager,
+    DictConfigWrap,
     _agent_loop_registry,
-    _DummyConfig,
     get_trajectory_info,
 )
 from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config
@@ -182,7 +182,7 @@ async def _partial_run_agent_loop(
         agent_loop_config = _agent_loop_registry[agent_name]
         agent_loop = hydra.utils.instantiate(
             config=agent_loop_config,
-            trainer_config=_DummyConfig(config=self.config),
+            trainer_config=DictConfigWrap(config=self.config),
             server_manager=self.server_manager,
             tokenizer=self.tokenizer,
             processor=self.processor,

verl/experimental/agent_loop/agent_loop.py

Lines changed: 7 additions & 23 deletions

@@ -175,21 +175,20 @@ class _InternalAgentLoopOutput(AgentLoopOutput):
     """Extra fields for dynamic addition."""


-# make hydra.utils.instantiate happy
-class _DummyConfig:
-    def __init__(self, config: DictConfig) -> None:
+class DictConfigWrap:
+    """Wrapper for DictConfig to avoid hydra.utils.instantiate recursive resolve."""
+
+    def __init__(self, config: DictConfig):
         self.config = config


 class AgentLoopBase(ABC):
     """An agent loop takes an input message, chat with OpenAI compatible LLM server and interact with various
     environments."""

-    _class_initialized = False
-
     def __init__(
         self,
-        trainer_config: _DummyConfig,
+        trainer_config: DictConfigWrap,
         server_manager: AsyncLLMServerManager,
         tokenizer: AutoTokenizer,
         processor: AutoProcessor,
@@ -198,32 +197,17 @@ def __init__(
         """Initialize agent loop, each sample will have its own loop instance.

         Args:
-            trainer_config (_DummyConfig): trainer config.
+            trainer_config (DictConfigWrap): trainer config.
             server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager.
             tokenizer (AutoTokenizer): Tokenizer for tokenize messages.
             processor (AutoProcessor): Processor for process messages.
         """
-        self.init_class(config=trainer_config.config, tokenizer=tokenizer, processor=processor, **kwargs)
         self.config = trainer_config.config
         self.server_manager = server_manager
         self.tokenizer = tokenizer
         self.processor = processor
         self.loop = asyncio.get_running_loop()

-    @classmethod
-    def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer, processor: AutoProcessor, **kwargs):
-        """This is used to do heavy initialization work that should shared across all instances. It's only called once.
-
-        Args:
-            config (DictConfig): trainer config.
-            tokenizer (AutoTokenizer): Tokenizer for tokenize messages.
-            processor (AutoProcessor): Processor for process multi_modal data.
-            **kwargs: extra kwargs from config file passed in by `hydra.utils.instantiate`.
-        """
-        if cls._class_initialized:
-            return
-        cls._class_initialized = True
-
     @abstractmethod
     async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
         """Run agent loop to interact with LLM server and environment.
@@ -420,7 +404,7 @@ async def _run_agent_loop(
         agent_loop_config = _agent_loop_registry[agent_name]
         agent_loop = hydra.utils.instantiate(
             config=agent_loop_config,
-            trainer_config=_DummyConfig(config=self.config),
+            trainer_config=DictConfigWrap(config=self.config),
             server_manager=self.server_manager,
             tokenizer=self.tokenizer,
             processor=self.processor,
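
Note: a minimal sketch of why the wrapper exists. hydra.utils.instantiate recursively converts and resolves nested DictConfig values, so the trainer config rides through inside a plain Python object and is unwrapped by the loop. MyAgentLoop and the config keys below are hypothetical, for illustration only; they are not part of the verl codebase.

import hydra
from omegaconf import DictConfig, OmegaConf


class DictConfigWrap:
    """Wrapper for DictConfig to avoid hydra.utils.instantiate recursive resolve."""

    def __init__(self, config: DictConfig):
        self.config = config


class MyAgentLoop:  # hypothetical stand-in for an AgentLoopBase subclass
    def __init__(self, trainer_config: DictConfigWrap, name: str):
        # Unwrap to recover the original, unresolved DictConfig.
        self.config = trainer_config.config
        self.name = name


trainer_cfg = OmegaConf.create({"rollout": {"prompt_length": 512}})
loop_cfg = OmegaConf.create({"_target_": "__main__.MyAgentLoop", "name": "demo"})

# The wrapped config passes through instantiate as an opaque Python object,
# so hydra does not convert or resolve the nested DictConfig.
loop = hydra.utils.instantiate(loop_cfg, trainer_config=DictConfigWrap(trainer_cfg))
print(loop.config.rollout.prompt_length)  # 512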

verl/experimental/agent_loop/tool_agent_loop.py

Lines changed: 48 additions & 34 deletions

@@ -20,7 +20,15 @@
 from typing import Any, Optional
 from uuid import uuid4

-from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register
+from transformers import AutoProcessor, AutoTokenizer
+
+from verl.experimental.agent_loop.agent_loop import (
+    AgentLoopBase,
+    AgentLoopOutput,
+    AsyncLLMServerManager,
+    DictConfigWrap,
+    register,
+)
 from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser
 from verl.experimental.agent_loop.utils import add_generation_prompt_for_gpt_oss, format_gpt_oss_tool_response_manually
 from verl.interactions.base import BaseInteraction
@@ -44,7 +52,8 @@ class AgentState(Enum):


 class AgentData:
-    """Encapsulates all state variables for the agent loop."""
+    """Encapsulates all state variables for the agent loop. AgentData is passed to tool calling in case that
+    tool may need to access full history state. User can store any tool session data in `extra_fields`."""

     def __init__(
         self,
@@ -77,44 +86,49 @@ def __init__(
         # Temporary state for tool calls
         self.tool_calls: list[FunctionCall] = []

-        # Extra fields for dynamic addition
+        # Extra fields for dynamic addition, e.g., tool session data
         self.extra_fields: dict[str, Any] = {}


 @register("tool_agent")
 class ToolAgentLoop(AgentLoopBase):
-    @classmethod
-    def init_class(cls, config, tokenizer, processor, **kwargs):
-        if cls._class_initialized:
-            return
-        cls._class_initialized = True
-        print("Performing class-level ToolAgentLoop initialization")
+    def __init__(
+        self,
+        trainer_config: DictConfigWrap,
+        server_manager: AsyncLLMServerManager,
+        tokenizer: AutoTokenizer,
+        processor: AutoProcessor,
+        **kwargs,
+    ):
+        super().__init__(trainer_config, server_manager, tokenizer, processor, **kwargs)
+        config = trainer_config.config

         # Initialize tools from config file
-        cls.tokenizer = tokenizer
-        cls.processor = processor
-        cls.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns
-        cls.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns
-        cls.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls
-        cls.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length
-        cls.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side
+        self.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns
+        self.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns
+        self.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls
+        self.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length
+        self.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side
         tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path
         tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
-        cls.tools = {tool.name: tool for tool in tool_list}
-        cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
-        cls.tool_parser = ToolParser.get_tool_parser(config.actor_rollout_ref.rollout.multi_turn.format, cls.tokenizer)
-        cls.tool_parser_name = config.actor_rollout_ref.rollout.multi_turn.format
-        print(f"Initialized tools: {cls.tools}")
+        self.tools = {tool.name: tool for tool in tool_list}
+        self.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
+        self.tool_parser = ToolParser.get_tool_parser(
+            config.actor_rollout_ref.rollout.multi_turn.format, self.tokenizer
+        )
+        self.tool_parser_name = config.actor_rollout_ref.rollout.multi_turn.format

-        cls.apply_chat_template_kwargs = config.data.get("apply_chat_template_kwargs", {})
-        cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length
-        cls.response_length = config.actor_rollout_ref.rollout.response_length
-        cls.system_prompt = initialize_system_prompt(cls.tokenizer, **cls.apply_chat_template_kwargs)
+        self.apply_chat_template_kwargs = config.data.get("apply_chat_template_kwargs", {})
+        self.prompt_length = config.actor_rollout_ref.rollout.prompt_length
+        self.response_length = config.actor_rollout_ref.rollout.response_length
+        self.system_prompt = initialize_system_prompt(self.tokenizer, **self.apply_chat_template_kwargs)

         # Initialize interactions from config file
-        cls.interaction_config_file = config.actor_rollout_ref.rollout.multi_turn.interaction_config_path
-        if cls.interaction_config_file:
-            cls.interaction_map: dict[str, BaseInteraction] = cls._initialize_interactions(cls.interaction_config_file)
+        self.interaction_config_file = config.actor_rollout_ref.rollout.multi_turn.interaction_config_path
+        if self.interaction_config_file:
+            self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions(
+                self.interaction_config_file
+            )

     @rollout_trace_op
     async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
@@ -271,7 +285,7 @@ async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentState:
         tasks = []
         tool_call_names = []
         for tool_call in agent_data.tool_calls[: self.max_parallel_calls]:
-            tasks.append(self._call_tool(tool_call, agent_data.tools_kwargs))
+            tasks.append(self._call_tool(tool_call, agent_data.tools_kwargs, agent_data))
             tool_call_names.append(tool_call.name)

         with simple_timer("tool_calls", agent_data.metrics):
@@ -434,7 +448,7 @@ async def _handle_interacting_state(self, agent_data: AgentData) -> AgentState:
         return AgentState.GENERATING

     async def _call_tool(
-        self, tool_call: FunctionCall, tools_kwargs: dict[str, Any]
+        self, tool_call: FunctionCall, tools_kwargs: dict[str, Any], agent_data: AgentData
     ) -> tuple[ToolResponse, float, dict]:
         """Call tool and return tool response."""
         tool, instance_id = None, None
@@ -445,7 +459,9 @@ async def _call_tool(
             tool = self.tools[tool_name]
             kwargs = tools_kwargs.get(tool_name, {})
             instance_id, _ = await tool.create(create_kwargs=kwargs.get("create_kwargs", {}))
-            tool_execution_response, tool_reward, res = await tool.execute(instance_id, tool_args)
+            tool_execution_response, tool_reward, res = await tool.execute(
+                instance_id, tool_args, agent_data=agent_data
+            )
         except Exception as e:
             logger.warning(f"Error when executing tool: {e}")
             return (
@@ -481,8 +497,7 @@ async def _call_tool(

         return ToolResponse(**tool_response_kwargs), tool_reward, res

-    @classmethod
-    def _initialize_interactions(cls, interaction_config_file):
+    def _initialize_interactions(self, interaction_config_file):
         """Initialize interactions from configuration.
         Returns:
             dict[str, BaseInteraction]: A dictionary mapping interaction names to interaction instances.
@@ -491,5 +506,4 @@ def _initialize_interactions(cls, interaction_config_file):
             return {}

         interaction_map = initialize_interactions_from_config(interaction_config_file)
-        logger.info(f"Initialize interactions from configuration: interaction_map: {list(interaction_map.keys())}")
         return interaction_map
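
Note: a hedged sketch of a tool that consumes the new kwarg. Since _call_tool now invokes tool.execute(instance_id, tool_args, agent_data=agent_data), a tool can accept it through **kwargs and keep per-trajectory session state in agent_data.extra_fields. The EchoCounterTool class, its field names, and the plain-string response below are illustrative assumptions, not verl's BaseTool/ToolResponse API.

from typing import Any


class EchoCounterTool:
    """Illustrative tool: counts its own calls via agent_data.extra_fields."""

    name = "echo_counter"

    async def create(self, create_kwargs: dict | None = None):
        # Mirrors the `instance_id, _ = await tool.create(...)` call site above.
        return "instance-0", None

    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs):
        # ToolAgentLoop._call_tool now passes agent_data=agent_data; accepting it
        # through **kwargs keeps tools that ignore the new kwarg working unchanged.
        agent_data = kwargs.get("agent_data")
        if agent_data is not None:
            # Full rollout history is reachable here; extra_fields is the
            # designated scratch space for tool session data.
            count = agent_data.extra_fields.get("echo_calls", 0) + 1
            agent_data.extra_fields["echo_calls"] = count
        # The real interface returns (ToolResponse, tool_reward, res); a plain
        # string stands in for the response object in this sketch.
        return f"echo: {parameters}", 0.0, {}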

verl/workers/rollout/replica.py

Lines changed: 2 additions & 13 deletions

@@ -18,7 +18,7 @@
 from enum import Enum
 from typing import Any, Callable, Optional

-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig
 from pydantic import BaseModel
 from ray.actor import ActorHandle

@@ -90,18 +90,7 @@ def __init__(
     ) -> None:
         self.replica_rank = replica_rank
         self.config = omega_conf_to_dataclass(config)
-        # TODO: make lora config irrelevant to the model engine choice
-        # Convert megatron lora config to HFModelConfig
-        # If model_config is not an OmegaConf object, convert it first
-        if OmegaConf.is_config(model_config):
-            model_config_dict = OmegaConf.to_container(model_config)
-            model_config_dict.pop("lora", None)
-
-            self.model_config: HFModelConfig = omega_conf_to_dataclass(
-                OmegaConf.create(model_config_dict), dataclass_type=HFModelConfig
-            )
-        else:
-            self.model_config: HFModelConfig = model_config
+        self.model_config: HFModelConfig = model_config

         self.world_size = (
             self.config.tensor_model_parallel_size
