44
55from .templates .templates import get_template
66from ..__init__ import AGENT_DATA_DIR
7- from .llm_backend import AsyncVLLMBackend , AsyncVerlBackend , ClientBackend , TransformersBackend , VLLMBackend , VerlBackend
7+ from .llm_backend import AsyncVLLMBackend , AsyncVerlBackend , ClientBackend , TransformersBackend , VLLMBackend
88from ..utils .logging import get_logger
99from typing import Any , Callable , Dict , List , Optional , Tuple , Union
1010import numpy as np
1111import torch
12- from .templates .utils import is_vlm_template , tokenize_conversations
12+ from .templates .utils import tokenize_conversations
13+ from .templates .vision_processor import is_vision_template
1314from .chain .chain_base import ChainGeneration
1415import os
1516import transformers
1617import warnings
18+ import logging
1719from .chain .streaming_observer import ConsoleStreamObserver , StreamingManager
20+ from .utils .tokenizer import create_tokenizer
21+ from .backend_config import BACKEND_CONFIGS
1822try :
1923 from verl .protocol import DataProto
2024except ImportError :
2125 print ("verl can not be imported." )
2226 pass
2327
28+ Logger = logging .getLogger (__name__ )
2429
2530class BaseAgent (ChainGeneration , ABC ):
2631 """
@@ -34,12 +39,13 @@ class BaseAgent(ChainGeneration, ABC):
3439 def __init__ (
3540 self ,
3641 model_name_or_path ,
37- template : str ,
42+ template : str = None ,
3843 system_prompt : str = None ,
3944 tools : List = None ,
4045 max_length : int = 8192 ,
4146 debug : bool = False ,
4247 backend : str = "transformers" ,
48+ backend_config : Any = None ,
4349 reward_fn : Callable = None ,
4450 log_file : str = "agent" ,
4551 project_name : str = None ,
@@ -65,9 +71,30 @@ def __init__(
6571 self .tools = tools
6672 self .system_prompt = system_prompt
6773 self .model_name_or_path = model_name_or_path
68- self .llm_engine , self .tokenizer , self .processor = self ._init_llm_engine (model_name_or_path , backend )
74+
75+ # Handle backend configuration
76+ if backend_config is None :
77+ # Use default configuration for the backend
78+ config_class = BACKEND_CONFIGS .get (backend )
79+ if config_class :
80+ self .backend_config = config_class ()
81+ else :
82+ self .backend_config = None
83+ else :
84+ self .backend_config = backend_config
85+
86+ self .llm_engine = self ._init_llm_engine (model_name_or_path , backend )
87+
88+ # Create appropriate tokenizer for trajectory processing
89+ self .tokenizer = create_tokenizer (model_name_or_path )
90+
6991 self ._reward_fn = reward_fn
70- self .jinja_template = get_template (self .template ).jinja_template ()
92+
93+ if self .template is None :
94+ self .jinja_template = None
95+ else :
96+ self .jinja_template = get_template (self .template ).jinja_template ()
97+
7198 self .project_name = project_name
7299 self .run_name = run_name
73100 self .streaming_manager = StreamingManager ()
@@ -78,38 +105,59 @@ def __init__(
78105 raise ValueError (f"Streaming mode { streaming } is not supported." )
79106 super ().__init__ ()
80107 if kwargs :
81- warnings .warn (f"Unused arguments for agent initialization: { kwargs } " )
108+ # warnings.warn(f"Unused arguments for agent initialization: {kwargs}")
109+ raise ValueError (f"Unused arguments for agent initialization: { kwargs } " )
82110
def _init_llm_engine(self, model_name_or_path: str, backend: str):
    """
    Build the LLM backend wrapper selected by ``backend``.

    Public (non-underscore) fields of ``self.backend_config`` are forwarded
    to the backend constructor as keyword arguments, alongside
    ``self.template`` and ``self.max_length``.

    Args:
        model_name_or_path: Model identifier handed to the backend. Must be
            a string.
        backend: One of ``"transformers"``, ``"async_vllm"``,
            ``"async_verl"`` or ``"client"``.

    Returns:
        The constructed backend instance.

    Raises:
        ValueError: If ``model_name_or_path`` is not a string, or if
            ``backend`` is not one of the supported names.
    """
    if not isinstance(model_name_or_path, str):
        raise ValueError("model_name_or_path must be a string.")

    # ``backend_config`` is typed ``Any``: accept either a config object
    # (fields read through ``vars``) or a plain dict of keyword arguments.
    backend_config = self.backend_config
    if not backend_config:
        raw_fields = {}
    elif isinstance(backend_config, dict):
        raw_fields = backend_config
    else:
        raw_fields = vars(backend_config)

    # Underscore-prefixed fields are treated as private and not forwarded.
    config_kwargs = {
        key: value
        for key, value in raw_fields.items()
        if not key.startswith("_")
    }
    # NOTE(review): a config field named ``max_length`` (or, for async_verl,
    # ``template`` / ``model_name_or_path`` / ``llm_engine``) would collide
    # with the explicit keyword arguments below and raise TypeError --
    # confirm that the config classes never define those fields.

    if backend == "transformers":
        return TransformersBackend(
            model_name_or_path,
            self.template,
            max_length=self.max_length,
            **config_kwargs,
        )
    if backend == "async_vllm":
        return AsyncVLLMBackend(
            model_name_or_path,
            self.template,
            max_length=self.max_length,
            **config_kwargs,
        )
    if backend == "async_verl":
        # The inner engine is left unset at construction time.
        return AsyncVerlBackend(
            llm_engine=None,
            model_name_or_path=model_name_or_path,
            template=self.template,
            max_length=self.max_length,
            **config_kwargs,
        )
    if backend == "client":
        # Log only the option names: the values may hold credentials and
        # must not be dumped to stdout.
        logging.getLogger(__name__).debug(
            "client backend config keys: %s", sorted(config_kwargs)
        )
        return ClientBackend(
            model_name_or_path,
            self.template,
            max_length=self.max_length,
            **config_kwargs,
        )

    raise ValueError(f"Backend {backend} is not supported.")
108155
def set_llm_engine(self, llm_engine: Any, tokenizer: Any, processor: Any = None):
    """
    Attach an externally created engine, tokenizer and processor to the agent.

    Args:
        llm_engine: Engine object stored on the backend wrapper as
            ``self.llm_engine.llm_engine``.
        tokenizer: Stored as ``self.tokenizer``.
        processor: Stored as ``self.processor``. Optional, so callers using
            the earlier two-argument form keep working; defaults to ``None``.

    Raises:
        AssertionError: If ``self.backend`` is not ``"async_verl"``.
    """
    # Raised explicitly (rather than via ``assert``) so the check is not
    # stripped under ``python -O``; the exception type is kept unchanged.
    if self.backend != "async_verl":
        raise AssertionError("Only async verl backend is supported for now")

    self.llm_engine.llm_engine = llm_engine
    self.tokenizer = tokenizer
    self.processor = processor
113161
def generate(self, messages_list_or_inputs: List[List[Dict]], **args):
    """
    Delegate generation to the configured LLM engine.

    Args:
        messages_list_or_inputs: Passed through unchanged to the engine.
        **args: Extra keyword arguments forwarded to the engine's
            ``generate`` method.

    Returns:
        Whatever ``self.llm_engine.generate`` returns.
    """
    engine = self.llm_engine
    return engine.generate(messages_list_or_inputs, **args)
@@ -151,25 +199,17 @@ async def generate_streaming(self, messages_list_or_inputs: List[List[Dict]], st
@property
def timing_data(self):
    """Return the ``timing_data`` attribute of this agent's ``timer``."""
    timer = self.timer
    return timer.timing_data
154-
155- def forward (self , messages_list_or_inputs : List [List [Dict ]], ** args ):
156- if isinstance (messages_list_or_inputs , List ):
157- inputs = tokenize_conversations (messages_list_or_inputs , tokenizer = self .tokenizer , conv_template = self .template , max_length = self .max_length , processor = self .processor )
158- else :
159- raise ValueError ("messages_list_or_inputs must be a list of messages or a dictionary of padded inputs." )
160-
161- if isinstance (self .llm_engine , transformers .PreTrainedModel ):
162- return self .llm_engine .forward (** inputs , ** args ) # Only support transformers models for now.
163- else :
164- raise ValueError ("llm_engine must be a transformers.PretrainedModel." )
165202
@property
def trajectories(self):
    """Return the result of ``self.get_messages()``."""
    return self.get_messages()
171208
172- def tokenize_trajectories (self , return_action_mask : bool = False , return_reward_mask : bool = False ):
209+ def tokenize_trajectories (self , tokenizer , return_action_mask : bool = False , return_reward_mask : bool = False ):
210+ if tokenizer is None :
211+ tokenizer = self .tokenizer
212+
173213 trajectories = self .trajectories
174214 self .logger .info ("================ Trajectory ================" )
175215 self .logger .info (trajectories [0 ])
@@ -196,7 +236,7 @@ def tokenize_trajectories(self, return_action_mask: bool = False, return_reward_
196236 info ['last_response' ] = last_response
197237 other_info_list .append (info )
198238
199- inputs = tokenize_conversations (messages_list , tokenizer = self . tokenizer , conv_template = self .template , processor = self .processor , max_length = self .max_length , return_reward_mask = return_reward_mask )
239+ inputs = tokenize_conversations (messages_list , tokenizer = tokenizer , conv_template = self .template , processor = self .processor , max_length = self .max_length , return_reward_mask = return_reward_mask )
200240 position_ids = torch .clip (torch .cumsum (inputs ['attention_mask' ], dim = - 1 ) - 1 , min = 0 , max = None )
201241 inputs ['position_ids' ] = position_ids
202242
0 commit comments