Skip to content

Commit d4abd9f

Browse files
authored
Merge pull request #3 from Agent-One-Lab/agents
More powerful template system
2 parents 62402a6 + c8af7ea commit d4abd9f

26 files changed

+2533
-888
lines changed

agents/agents/agents/agent_base.py

Lines changed: 74 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,28 @@
44

55
from .templates.templates import get_template
66
from ..__init__ import AGENT_DATA_DIR
7-
from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend, VerlBackend
7+
from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend
88
from ..utils.logging import get_logger
99
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1010
import numpy as np
1111
import torch
12-
from .templates.utils import is_vlm_template, tokenize_conversations
12+
from .templates.utils import tokenize_conversations
13+
from .templates.vision_processor import is_vision_template
1314
from .chain.chain_base import ChainGeneration
1415
import os
1516
import transformers
1617
import warnings
18+
import logging
1719
from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager
20+
from .utils.tokenizer import create_tokenizer
21+
from .backend_config import BACKEND_CONFIGS
1822
try:
1923
from verl.protocol import DataProto
2024
except ImportError:
2125
print("verl can not be imported.")
2226
pass
2327

28+
Logger = logging.getLogger(__name__)
2429

2530
class BaseAgent(ChainGeneration, ABC):
2631
"""
@@ -34,12 +39,13 @@ class BaseAgent(ChainGeneration, ABC):
3439
def __init__(
3540
self,
3641
model_name_or_path,
37-
template: str,
42+
template: str=None,
3843
system_prompt: str = None,
3944
tools: List = None,
4045
max_length: int=8192,
4146
debug: bool = False,
4247
backend: str = "transformers",
48+
backend_config: Any = None,
4349
reward_fn: Callable = None,
4450
log_file: str = "agent",
4551
project_name: str = None,
@@ -65,9 +71,30 @@ def __init__(
6571
self.tools = tools
6672
self.system_prompt = system_prompt
6773
self.model_name_or_path = model_name_or_path
68-
self.llm_engine, self.tokenizer, self.processor = self._init_llm_engine(model_name_or_path, backend)
74+
75+
# Handle backend configuration
76+
if backend_config is None:
77+
# Use default configuration for the backend
78+
config_class = BACKEND_CONFIGS.get(backend)
79+
if config_class:
80+
self.backend_config = config_class()
81+
else:
82+
self.backend_config = None
83+
else:
84+
self.backend_config = backend_config
85+
86+
self.llm_engine = self._init_llm_engine(model_name_or_path, backend)
87+
88+
# Create appropriate tokenizer for trajectory processing
89+
self.tokenizer = create_tokenizer(model_name_or_path)
90+
6991
self._reward_fn = reward_fn
70-
self.jinja_template = get_template(self.template).jinja_template()
92+
93+
if self.template is None:
94+
self.jinja_template = None
95+
else:
96+
self.jinja_template = get_template(self.template).jinja_template()
97+
7198
self.project_name = project_name
7299
self.run_name = run_name
73100
self.streaming_manager = StreamingManager()
@@ -78,38 +105,59 @@ def __init__(
78105
raise ValueError(f"Streaming mode {streaming} is not supported.")
79106
super().__init__()
80107
if kwargs:
81-
warnings.warn(f"Unused arguments for agent initialization: {kwargs}")
108+
# warnings.warn(f"Unused arguments for agent initialization: {kwargs}")
109+
raise ValueError(f"Unused arguments for agent initialization: {kwargs}")
82110

83111
def _init_llm_engine(self, model_name_or_path: str, backend: str):
    """Build the generation backend selected by ``backend``.

    The public entries of ``self.backend_config`` (names that do not start
    with an underscore) are forwarded to the backend constructor as keyword
    arguments, alongside ``self.template`` and ``self.max_length``.
    ``self.backend_config`` may be an object with a ``__dict__`` (such as a
    dataclass instance), a plain ``dict``, or ``None``.

    Args:
        model_name_or_path: Model identifier handed to the backend. Must be
            a ``str``.
        backend: One of ``"transformers"``, ``"async_vllm"``,
            ``"async_verl"`` or ``"client"``.

    Returns:
        The constructed backend object.

    Raises:
        ValueError: If ``model_name_or_path`` is not a string, or if
            ``backend`` is not one of the names listed above.
    """
    if not isinstance(model_name_or_path, str):
        raise ValueError("model_name_or_path must be a string.")

    # Extract backend-specific configuration. A plain dict has no __dict__,
    # so it is used directly instead of failing with AttributeError.
    backend_config = self.backend_config
    if not backend_config:
        raw_options = {}
    elif isinstance(backend_config, dict):
        raw_options = backend_config
    else:
        raw_options = backend_config.__dict__
    config_kwargs = {
        key: value
        for key, value in raw_options.items()
        if not key.startswith('_')
    }

    if backend == "transformers":
        llm_engine = TransformersBackend(
            model_name_or_path,
            self.template,
            max_length=self.max_length,
            **config_kwargs
        )
    elif backend == "async_vllm":
        llm_engine = AsyncVLLMBackend(
            model_name_or_path,
            self.template,
            max_length=self.max_length,
            **config_kwargs
        )
    elif backend == "async_verl":
        # The verl engine is attached later, so it starts out as None here.
        llm_engine = AsyncVerlBackend(
            llm_engine=None,
            model_name_or_path=model_name_or_path,
            template=self.template,
            max_length=self.max_length,
            **config_kwargs
        )
    elif backend == "client":
        # The options are deliberately not printed: the client configuration
        # can carry an api_key, which must not end up on stdout.
        llm_engine = ClientBackend(
            model_name_or_path,
            self.template,
            max_length=self.max_length,
            **config_kwargs
        )
    else:
        raise ValueError(f"Backend {backend} is not supported.")

    return llm_engine
108155

109-
def set_llm_engine(self, llm_engine: Any, tokenizer: Any, processor: Any = None):
    """Attach an externally created engine, tokenizer and processor.

    Args:
        llm_engine: Engine object stored on the backend wrapper as
            ``self.llm_engine.llm_engine``.
        tokenizer: Replaces ``self.tokenizer``.
        processor: Replaces ``self.processor``. Optional, so that callers
            written against the earlier two-argument signature keep
            working; it defaults to ``None``.

    Raises:
        AssertionError: If ``self.backend`` is not ``"async_verl"``.
    """
    # NOTE(review): an assert is skipped when Python runs with -O; it is kept
    # so callers that expect AssertionError see no change in behaviour.
    assert self.backend == "async_verl", "Only async verl backend is supported for now"
    self.llm_engine.llm_engine = llm_engine
    self.tokenizer = tokenizer
    self.processor = processor
113161

114162
def generate(self, messages_list_or_inputs: List[List[Dict]], **args):
    """Run generation by delegating to the backend engine.

    Args:
        messages_list_or_inputs: Passed through unchanged as the first
            argument of ``self.llm_engine.generate``.
        **args: Extra keyword arguments forwarded to the same call.

    Returns:
        Whatever ``self.llm_engine.generate`` returns.
    """
    return self.llm_engine.generate(messages_list_or_inputs, **args)
@@ -151,25 +199,17 @@ async def generate_streaming(self, messages_list_or_inputs: List[List[Dict]], st
151199
@property
def timing_data(self):
    """Return the ``timing_data`` attribute of ``self.timer``."""
    return self.timer.timing_data
154-
155-
def forward(self, messages_list_or_inputs: List[List[Dict]], **args):
156-
if isinstance(messages_list_or_inputs, List):
157-
inputs = tokenize_conversations(messages_list_or_inputs, tokenizer=self.tokenizer, conv_template=self.template, max_length=self.max_length, processor=self.processor)
158-
else:
159-
raise ValueError("messages_list_or_inputs must be a list of messages or a dictionary of padded inputs.")
160-
161-
if isinstance(self.llm_engine, transformers.PreTrainedModel):
162-
return self.llm_engine.forward(**inputs, **args) # Only support transformers models for now.
163-
else:
164-
raise ValueError("llm_engine must be a transformers.PretrainedModel.")
165202

166203
@property
def trajectories(self):
    """Return the result of ``self.get_messages()``."""
    return self.get_messages()
171208

172-
def tokenize_trajectories(self, return_action_mask: bool = False, return_reward_mask: bool = False):
209+
def tokenize_trajectories(self, tokenizer, return_action_mask: bool = False, return_reward_mask: bool = False):
210+
if tokenizer is None:
211+
tokenizer = self.tokenizer
212+
173213
trajectories = self.trajectories
174214
self.logger.info("================ Trajectory ================")
175215
self.logger.info(trajectories[0])
@@ -196,7 +236,7 @@ def tokenize_trajectories(self, return_action_mask: bool = False, return_reward_
196236
info['last_response'] = last_response
197237
other_info_list.append(info)
198238

199-
inputs = tokenize_conversations(messages_list, tokenizer=self.tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask)
239+
inputs = tokenize_conversations(messages_list, tokenizer=tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask)
200240
position_ids = torch.clip(torch.cumsum(inputs['attention_mask'], dim=-1) - 1, min=0, max=None)
201241
inputs['position_ids'] = position_ids
202242

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from dataclasses import dataclass
2+
from typing import Optional, Dict, Any, List
3+
import asyncio
4+
5+
6+
@dataclass
class TransformersConfig:
    """Configuration for Transformers backend.

    ``BaseAgent._init_llm_engine`` expands the public attributes of this
    object into keyword arguments for the backend constructor.
    """
    # NOTE(review): presumably the sampling temperature - confirm in the backend.
    temperature: float = 1.0
    # NOTE(review): presumably the per-call generation cap - confirm in the backend.
    max_new_tokens: int = 1024
    # NOTE(review): looks like the Hugging Face loading options of the same
    # names - confirm how the backend consumes them.
    trust_remote_code: bool = True
    device_map: str = "auto"
13+
14+
15+
@dataclass
class VLLMConfig:
    """Configuration for VLLM backend."""
    # NOTE(review): presumably the sampling temperature - confirm in the backend.
    temperature: float = 1.0
    # NOTE(review): presumably the per-call generation cap - confirm in the backend.
    max_new_tokens: int = 1024
    # Add other vLLM specific parameters as needed
21+
22+
23+
@dataclass
class AsyncVLLMConfig:
    """Configuration for Async VLLM backend.

    ``BaseAgent._init_llm_engine`` expands the public attributes of this
    object into keyword arguments for the backend constructor.
    """
    # NOTE(review): presumably the sampling temperature - confirm in the backend.
    temperature: float = 1.0
    # NOTE(review): presumably the per-call generation cap - confirm in the backend.
    max_new_tokens: int = 1024
    # Add other async vLLM specific parameters as needed
29+
30+
31+
@dataclass
class VerlConfig:
    """Configuration for Verl backend."""
    # NOTE(review): presumably the sampling temperature - confirm in the backend.
    temperature: float = 1.0
    # NOTE(review): presumably the per-call generation cap - confirm in the backend.
    max_new_tokens: int = 1024
    # Add other Verl specific parameters as needed
37+
38+
39+
@dataclass
class AsyncVerlConfig:
    """Configuration for Async Verl backend.

    ``BaseAgent._init_llm_engine`` expands the public attributes of this
    object into keyword arguments for the backend constructor.
    """
    # NOTE(review): presumably the sampling temperature - confirm in the backend.
    temperature: float = 1.0
    # NOTE(review): presumably the per-call generation cap - confirm in the backend.
    max_new_tokens: int = 1024
    # Add other async Verl specific parameters as needed
45+
46+
47+
@dataclass
class ClientConfig:
    """Configuration for Client backend (OpenAI-compatible).

    ``BaseAgent._init_llm_engine`` expands the public attributes of this
    object into keyword arguments for the backend constructor.
    """
    # Endpoint root; the default points at a server on localhost.
    base_url: str = "http://localhost:8000/v1"
    # NOTE(review): presumably a client-side rate limit - confirm in the backend.
    max_requests_per_minute: int = 100
    # NOTE(review): unit is not visible here (presumably seconds) - confirm.
    timeout: int = 600
    # Credential for the endpoint. The generated __repr__ includes this value,
    # so avoid printing or logging instances that hold a real key.
    api_key: str = "EMPTY"
    # NOTE(review): presumably the per-call generation cap - confirm in the backend.
    max_new_tokens: int = 1024
    # NOTE(review): presumably the sampling temperature - confirm in the backend.
    temperature: float = 1.0
56+
57+
58+
# Backend configuration mapping: backend name -> configuration class.
# BaseAgent looks a class up by backend name and instantiates it with no
# arguments when the caller did not pass a backend_config, so every class
# listed here must be constructible without arguments.
BACKEND_CONFIGS: Dict[str, type] = {
    "transformers": TransformersConfig,
    "vllm": VLLMConfig,
    "async_vllm": AsyncVLLMConfig,
    "verl": VerlConfig,
    "async_verl": AsyncVerlConfig,
    "client": ClientConfig,
}

agents/agents/agents/chain/chain_base.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -316,14 +316,40 @@ async def _run_single_chain(self,
316316
# Handle tool calls
317317
if current_node.messages[-1].get("tool_calls"):
318318
for tool_call in current_node.messages[-1]["tool_calls"]:
319-
current_node = await self._execute_tool_call(
319+
result = await self._execute_tool_call(
320320
tool_call, newest_messages, chain, chain_id, depth,
321321
have_set_tools, enable_streaming
322322
)
323323
have_set_tools = True
324+
325+
# Create action input node
326+
action_input_node = chain.add_node(
327+
type="Action Input",
328+
messages=deepcopy(newest_messages),
329+
description=result.get("arguments", "")
330+
)
331+
332+
# Process observation
333+
observation = result["observation"]
334+
observation_json = json.dumps({
335+
"name": result["name"],
336+
"content": observation,
337+
}, indent=4)
338+
339+
action_input_node.observation = observation_json
340+
action_input_node.observation_code = result["status"]
341+
newest_messages.append({
342+
"role": "tool",
343+
"tool_call_id": tool_call["id"],
344+
"content": [{"type": "text", "text": observation_json}],
345+
})
346+
action_input_node.messages = deepcopy(newest_messages)
347+
action_input_node.is_terminal = result["status"] in self.terminal_status
324348
else:
325349
# No tool calls, chain is finished
326350
break
351+
352+
current_node = action_input_node
327353

328354
depth += 1
329355

@@ -463,33 +489,9 @@ async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id,
463489
step=depth,
464490
depth=depth
465491
))
492+
493+
return result
466494

467-
468-
# Create action input node
469-
action_input_node = chain.add_node(
470-
type="Action Input",
471-
messages=deepcopy(newest_messages),
472-
description=result.get("arguments", "")
473-
)
474-
475-
# Process observation
476-
observation = result["observation"]
477-
observation_json = json.dumps({
478-
"name": result["name"],
479-
"content": observation,
480-
}, indent=4)
481-
482-
action_input_node.observation = observation_json
483-
action_input_node.observation_code = result["status"]
484-
newest_messages.append({
485-
"role": "tool",
486-
"tool_call_id": tool_call["id"],
487-
"content": [{"type": "text", "text": observation_json}],
488-
})
489-
action_input_node.messages = deepcopy(newest_messages)
490-
action_input_node.is_terminal = result["status"] in self.terminal_status
491-
492-
return action_input_node
493495

494496
async def _finalize_chain(self, chain_id, chain, current_node, depth):
495497
"""Finalize the chain with reward calculation and cleanup."""

0 commit comments

Comments
 (0)