Skip to content

Commit ccf5bf8

Browse files
authored
Merge pull request #4 from Agent-One-Lab/agents
Add mock tests, fix training bug
2 parents d4abd9f + 5403dbf commit ccf5bf8

File tree

15 files changed

+1477
-73
lines changed

15 files changed

+1477
-73
lines changed

agents/agents/agents/agent_base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import warnings
1818
import logging
1919
from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager
20-
from .utils.tokenizer import create_tokenizer
20+
from .utils.tokenizer import create_processor, create_tokenizer
2121
from .backend_config import BACKEND_CONFIGS
2222
try:
2323
from verl.protocol import DataProto
@@ -43,14 +43,14 @@ def __init__(
4343
system_prompt: str = None,
4444
tools: List = None,
4545
max_length: int=8192,
46-
debug: bool = False,
4746
backend: str = "transformers",
4847
backend_config: Any = None,
4948
reward_fn: Callable = None,
5049
log_file: str = "agent",
5150
project_name: str = None,
5251
run_name: str = None,
5352
streaming: str = "console",
53+
debug: bool = False,
5454
**kwargs # To pass other unused arguments
5555
):
5656
"""
@@ -65,6 +65,7 @@ def __init__(
6565
"""
6666
torch.set_printoptions(threshold=10_000)
6767
self.logger = get_logger(directory=os.path.join(AGENT_DATA_DIR, "debug"), filename=log_file, level="DEBUG" if debug else "INFO")
68+
self.debug = debug
6869
self.backend = backend
6970
self.template = template
7071
self.max_length = max_length
@@ -87,6 +88,8 @@ def __init__(
8788

8889
# Create appropriate tokenizer for trajectory processing
8990
self.tokenizer = create_tokenizer(model_name_or_path)
91+
92+
self.processor = create_processor(model_name_or_path)
9093

9194
self._reward_fn = reward_fn
9295

@@ -105,8 +108,7 @@ def __init__(
105108
raise ValueError(f"Streaming mode {streaming} is not supported.")
106109
super().__init__()
107110
if kwargs:
108-
# warnings.warn(f"Unused arguments for agent initialization: {kwargs}")
109-
raise ValueError(f"Unused arguments for agent initialization: {kwargs}")
111+
warnings.warn(f"Unused arguments for agent initialization: {kwargs}")
110112

111113
def _init_llm_engine(self, model_name_or_path: str, backend: str):
112114
if isinstance(model_name_or_path, str):
@@ -206,7 +208,7 @@ def trajectories(self):
206208

207209
return trajectories
208210

209-
def tokenize_trajectories(self, tokenizer, return_action_mask: bool = False, return_reward_mask: bool = False):
211+
def tokenize_trajectories(self, tokenizer = None, return_reward_mask: bool = False):
210212
if tokenizer is None:
211213
tokenizer = self.tokenizer
212214

@@ -318,7 +320,7 @@ def rewards(self):
318320

319321

320322
def get_verl_data_proto(self):
321-
inputs, other_info_list = self.tokenize_trajectories(return_action_mask=True, return_reward_mask=True)
323+
inputs, other_info_list = self.tokenize_trajectories(return_reward_mask=True)
322324
group_ids = np.array([info["group_id"] for info in other_info_list], dtype=object)
323325
# Do evaluation here
324326
reward_values, other_values = self.rewards

agents/agents/agents/auto.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Type, Union
1+
from typing import Any, Callable, Dict, List, Optional, Type, Union
22

33
from .specialized.think_agent import ThinkAgent
44
from agents.agents.specialized.openai_agent import OpenAIAgent
@@ -8,8 +8,7 @@
88
from .specialized.code_agent import CodeAgent
99
from ..rewards.reward_base import get_reward_from_name
1010

11-
# Registry for agent types - will be populated dynamically
12-
AGENT_MAPPING = {}
11+
1312

1413
class AutoAgent:
1514
"""
@@ -22,7 +21,7 @@ class AutoAgent:
2221
These agents are registered automatically. Additional custom agents can be
2322
registered using the register_agent method.
2423
"""
25-
24+
AGENT_MAPPING = {}
2625
@classmethod
2726
def register_agent(cls, agent_type: str, agent_class: Type[BaseAgent]) -> None:
2827
"""
@@ -32,7 +31,7 @@ def register_agent(cls, agent_type: str, agent_class: Type[BaseAgent]) -> None:
3231
agent_type: The name identifier for the agent type (e.g., 'react', 'code')
3332
agent_class: The agent class to instantiate for this type
3433
"""
35-
AGENT_MAPPING[agent_type.lower()] = agent_class
34+
cls.AGENT_MAPPING[agent_type.lower()] = agent_class
3635

3736
@classmethod
3837
def _get_agent_class(cls, agent_type: str) -> Type[BaseAgent]:
@@ -50,11 +49,11 @@ def _get_agent_class(cls, agent_type: str) -> Type[BaseAgent]:
5049
"""
5150
agent_type = agent_type.lower()
5251

53-
if agent_type not in AGENT_MAPPING:
54-
available_types = list(AGENT_MAPPING.keys())
52+
if agent_type not in cls.AGENT_MAPPING:
53+
available_types = list(cls.AGENT_MAPPING.keys())
5554
raise ValueError(f"Unknown agent type: '{agent_type}'. Available types: {available_types}")
5655

57-
return AGENT_MAPPING[agent_type]
56+
return cls.AGENT_MAPPING[agent_type]
5857

5958
@classmethod
6059
def from_config(cls, config: Dict[str, Any]) -> BaseAgent:
@@ -81,27 +80,36 @@ def from_config(cls, config: Dict[str, Any]) -> BaseAgent:
8180
An initialized agent instance.
8281
"""
8382
# Extract and validate required parameters
83+
if config is None:
84+
raise ValueError("Config could not be None")
85+
86+
# construct a copy for agent_kwargs
87+
agent_kwargs = {}
88+
for k, v in config.items():
89+
agent_kwargs[k] = v
90+
8491
required_params = ["agent_type", "template", "tools", "backend"]
8592
missing_params = [param for param in required_params if not config.get(param)]
8693

8794
if missing_params:
8895
raise ValueError(f"Missing required parameters: {', '.join(missing_params)}")
8996

9097
agent_type = config["agent_type"]
98+
agent_kwargs.pop("agent_type")
9199
tools = get_tools_from_names(config["tools"])
92100
agent_class = cls._get_agent_class(agent_type)
101+
reward_name = config.get("reward_name")
102+
if reward_name is not None:
103+
reward_fn = get_reward_from_name(reward_name)
104+
agent_kwargs.pop("reward_name")
105+
else:
106+
reward_fn = None
93107

94-
# construct a copy for agent_kwargs
95-
agent_kwargs = {}
96-
for k, v in config.items():
97-
agent_kwargs[k] = v
98-
99-
agent_kwargs.pop("agent_type")
100108
agent_kwargs['tools'] = tools
101-
if "reward_name" in config and config["reward_name"] is not None:
102-
agent_kwargs.pop("reward_name")
103-
reward_fn = get_reward_from_name(config["reward_name"])
104-
agent_kwargs["reward_fn"] = reward_fn
109+
agent_kwargs['reward_fn'] = reward_fn
110+
111+
if "use_agent" in agent_kwargs:
112+
agent_kwargs.pop("use_agent")
105113

106114
agent = agent_class(**agent_kwargs)
107115

@@ -114,11 +122,9 @@ def from_pretrained(
114122
agent_type: str,
115123
template: str,
116124
tools: Optional[List] = None,
117-
vllm: bool = False,
118125
debug: bool = False,
119126
log_file: str = "agent",
120-
wrapper: bool = False,
121-
reward_name: Optional[str] = None,
127+
reward_fn: Optional[Callable] = None,
122128
**kwargs
123129
) -> BaseAgent:
124130
"""
@@ -147,11 +153,9 @@ def from_pretrained(
147153
"model_name_or_path": model_name_or_path,
148154
"template": template,
149155
"tools": tools or [],
150-
"vllm": vllm,
151156
"debug": debug,
152157
"log_file": log_file,
153-
"wrapper": wrapper,
154-
"reward_name": reward_name,
158+
"reward_fn": reward_fn,
155159
**kwargs
156160
}
157161

agents/agents/agents/llm_backend.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import asyncio
66
from asyncore import loop
77
from collections import deque
8+
import copy
89
from functools import partial
910
import time
1011
from typing import Dict, Any, List, Optional, Callable, AsyncGenerator
1112
import uuid
12-
from .templates.utils import convert_messages_to_openai_format
1313
import numpy as np
1414
from tenacity import retry, stop_after_attempt, wait_exponential
1515
import torch
@@ -24,8 +24,8 @@
2424
import logging
2525
import PIL
2626

27+
2728
LOGGER = logging.getLogger(__name__)
28-
LOGGER.setLevel(logging.DEBUG)
2929

3030
try:
3131
from verl.protocol import DataProto
@@ -353,14 +353,29 @@ def _process_inputs(self, prompts: List[str], vision_inputs: Dict[str, List[PIL.
353353

354354
def generate(self, messages_list: str, **kwargs) -> str:
    """
    Synchronous generation entry point; not available on this backend.

    Raises:
        NotImplementedError: always, regardless of the arguments passed.
    """
    raise NotImplementedError(
        "Async Verl backend does not support sync generation"
    )
356+
357+
def _convert_to_openai_chat_without_tool_call_processing(self, messages: list) -> list:
358+
"""
359+
We use the pure generated content as the history. So we don't want any tool call to be part of the history.
360+
This is used when models are not openai's official models like GPT-4o.
361+
"""
362+
messages = copy.deepcopy(messages)
363+
for message in messages:
364+
if "tool_calls" in message:
365+
del message["tool_calls"]
366+
if "tool_call_id" in message:
367+
del message["tool_call_id"]
368+
if "tool_choice" in message:
369+
del message["tool_choice"]
370+
return messages
356371

357372
async def generate_async(self, messages_list: str, **kwargs) -> str:
358373
"""Generate text from prompt using Verl"""
359374
# We need to build a DataProto from the prompts
360375

361376
generation_config = {}
362377
tensors = torch.ones(len(messages_list), dtype=torch.int64)
363-
messages_list = [convert_messages_to_openai_format(messages) for messages in messages_list]
378+
messages_list = [self._convert_to_openai_chat_without_tool_call_processing(messages) for messages in messages_list]
364379
tools = kwargs.get("tools", None)
365380
tools_list = np.array([tools] * len(messages_list))
366381
data = {"input_ids": tensors, "raw_prompt": np.array(messages_list), "tools": tools_list}
@@ -457,6 +472,21 @@ async def _call(self, messages: List[List[Dict]], **kw) -> str:
457472
loop = asyncio.get_running_loop()
458473
return await loop.run_in_executor(None, partial(self._blocking_call, messages, **kw))
459474

475+
def _convert_to_openai_chat_without_tool_call_processing(self, messages: list) -> list:
476+
"""
477+
We use the pure generated content as the history. So we don't want any tool call to be part of the history.
478+
This is used when models are not openai's official models like GPT-4o.
479+
TODO: we need to add support for openai models
480+
"""
481+
messages = copy.deepcopy(messages)
482+
for message in messages:
483+
if "tool_calls" in message:
484+
del message["tool_calls"]
485+
if "tool_call_id" in message:
486+
del message["tool_call_id"]
487+
if "tool_choice" in message:
488+
del message["tool_choice"]
489+
return messages
460490

461491
# Public API ‑‑ sync or async depending on caller's context
462492
def async_generate(
@@ -478,7 +508,7 @@ def async_generate(
478508
else:
479509
messages_list = messages # batch
480510
print(f"[ClientBackend] messages_list: {messages_list}")
481-
messages_list = [convert_messages_to_openai_format(messages) for messages in messages_list]
511+
messages_list = [self._convert_to_openai_chat_without_tool_call_processing(messages) for messages in messages_list]
482512

483513
async def _runner():
484514
tasks = [asyncio.create_task(self._call(_input, **kwargs)) for _input in messages_list]

agents/agents/agents/react/react_agent.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ def __init__(self,
123123
model_name_or_path=model_name_or_path,
124124
tools=tools,
125125
system_prompt=system_prompt,
126-
max_length=8192,
127126
**kwargs
128127
)
129128

agents/agents/agents/templates/utils.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,6 @@ def strip_ansi(s: str) -> str:
2222
return ANSI_RE.sub('', s)
2323

2424

25-
def convert_messages_to_openai_format(messages: list) -> list:
26-
"""
27-
Convert messages to OpenAI format.
28-
TODO: add more processing for other types of content
29-
"""
30-
messages = copy.deepcopy(messages)
31-
for message in messages:
32-
# if "tool_calls" in message:
33-
# del message["tool_calls"]
34-
# if "tool_call_id" in message:
35-
# del message["tool_call_id"]
36-
if "tool_choice" in message:
37-
del message["tool_choice"]
38-
return messages
39-
40-
4125
def convert_messages_to_hf_format(messages: list) -> list:
4226
"""
4327
Convert messages to Hugging Face format.
@@ -305,9 +289,7 @@ def compare_hf_template(tokenizer, template_name, messages=None, tools=None, add
305289
plain_highlighted_prompt = strip_ansi(highlighted_prompt)
306290
is_equal_between_implemented_prompts = implemented_prompt == plain_highlighted_prompt
307291
jinja_template = chat.template.jinja_template()
308-
# Save jinja template to file
309-
with open("jinja_template.jinja", "w") as f:
310-
f.write(jinja_template)
292+
311293
tokenizer.chat_template = jinja_template
312294
implemented_jinja_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt)
313295
is_equal_between_jinja_prompts = implemented_jinja_prompt == implemented_prompt

agents/agents/agents/utils/tokenizer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from transformers import AutoTokenizer
1+
from transformers import AutoProcessor, AutoTokenizer
22

33
def create_tokenizer(model_name_or_path: str):
44
try:
@@ -8,3 +8,12 @@ def create_tokenizer(model_name_or_path: str):
88
tokenizer = None
99

1010
return tokenizer
11+
12+
13+
def create_processor(model_name_or_path: str):
    """
    Load the processor for ``model_name_or_path`` via ``AutoProcessor.from_pretrained``.

    Returns:
        The loaded processor, or ``None`` when loading raises ``OSError``.
    """
    try:
        return AutoProcessor.from_pretrained(model_name_or_path)
    except OSError:
        return None
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Mock tests package for agents

0 commit comments

Comments
 (0)