1- from typing import Any , Dict , List , Optional , Type , Union
1+ from typing import Any , Callable , Dict , List , Optional , Type , Union
22
33from .specialized .think_agent import ThinkAgent
44from agents .agents .specialized .openai_agent import OpenAIAgent
88from .specialized .code_agent import CodeAgent
99from ..rewards .reward_base import get_reward_from_name
1010
11- # Registry for agent types - will be populated dynamically
12- AGENT_MAPPING = {}
11+
1312
1413class AutoAgent :
1514 """
@@ -22,7 +21,7 @@ class AutoAgent:
2221 These agents are registered automatically. Additional custom agents can be
2322 registered using the register_agent method.
2423 """
25-
24+ AGENT_MAPPING = {}
2625 @classmethod
2726 def register_agent (cls , agent_type : str , agent_class : Type [BaseAgent ]) -> None :
2827 """
@@ -32,7 +31,7 @@ def register_agent(cls, agent_type: str, agent_class: Type[BaseAgent]) -> None:
3231 agent_type: The name identifier for the agent type (e.g., 'react', 'code')
3332 agent_class: The agent class to instantiate for this type
3433 """
35- AGENT_MAPPING [agent_type .lower ()] = agent_class
34+ cls . AGENT_MAPPING [agent_type .lower ()] = agent_class
3635
3736 @classmethod
3837 def _get_agent_class (cls , agent_type : str ) -> Type [BaseAgent ]:
@@ -50,11 +49,11 @@ def _get_agent_class(cls, agent_type: str) -> Type[BaseAgent]:
5049 """
5150 agent_type = agent_type .lower ()
5251
53- if agent_type not in AGENT_MAPPING :
54- available_types = list (AGENT_MAPPING .keys ())
52+ if agent_type not in cls . AGENT_MAPPING :
53+ available_types = list (cls . AGENT_MAPPING .keys ())
5554 raise ValueError (f"Unknown agent type: '{ agent_type } '. Available types: { available_types } " )
5655
57- return AGENT_MAPPING [agent_type ]
56+ return cls . AGENT_MAPPING [agent_type ]
5857
5958 @classmethod
6059 def from_config (cls , config : Dict [str , Any ]) -> BaseAgent :
@@ -81,27 +80,36 @@ def from_config(cls, config: Dict[str, Any]) -> BaseAgent:
8180 An initialized agent instance.
8281 """
8382 # Extract and validate required parameters
83+ if config is None :
84+ raise ValueError ("Config could not be None" )
85+
86+ # construct a copy for agent_kwargs
87+ agent_kwargs = {}
88+ for k , v in config .items ():
89+ agent_kwargs [k ] = v
90+
8491 required_params = ["agent_type" , "template" , "tools" , "backend" ]
8592 missing_params = [param for param in required_params if not config .get (param )]
8693
8794 if missing_params :
8895 raise ValueError (f"Missing required parameters: { ', ' .join (missing_params )} " )
8996
9097 agent_type = config ["agent_type" ]
98+ agent_kwargs .pop ("agent_type" )
9199 tools = get_tools_from_names (config ["tools" ])
92100 agent_class = cls ._get_agent_class (agent_type )
101+ reward_name = config .get ("reward_name" )
102+ if reward_name is not None :
103+ reward_fn = get_reward_from_name (reward_name )
104+ agent_kwargs .pop ("reward_name" )
105+ else :
106+ reward_fn = None
93107
94- # construct a copy for agent_kwargs
95- agent_kwargs = {}
96- for k , v in config .items ():
97- agent_kwargs [k ] = v
98-
99- agent_kwargs .pop ("agent_type" )
100108 agent_kwargs ['tools' ] = tools
101- if "reward_name" in config and config [ "reward_name" ] is not None :
102- agent_kwargs . pop ( "reward_name" )
103- reward_fn = get_reward_from_name ( config [ "reward_name" ])
104- agent_kwargs [ "reward_fn" ] = reward_fn
109+ agent_kwargs [ 'reward_fn' ] = reward_fn
110+
111+ if "use_agent" in agent_kwargs :
112+ agent_kwargs . pop ( "use_agent" )
105113
106114 agent = agent_class (** agent_kwargs )
107115
@@ -114,11 +122,9 @@ def from_pretrained(
114122 agent_type : str ,
115123 template : str ,
116124 tools : Optional [List ] = None ,
117- vllm : bool = False ,
118125 debug : bool = False ,
119126 log_file : str = "agent" ,
120- wrapper : bool = False ,
121- reward_name : Optional [str ] = None ,
127+ reward_fn : Optional [Callable ] = None ,
122128 ** kwargs
123129 ) -> BaseAgent :
124130 """
@@ -147,11 +153,9 @@ def from_pretrained(
147153 "model_name_or_path" : model_name_or_path ,
148154 "template" : template ,
149155 "tools" : tools or [],
150- "vllm" : vllm ,
151156 "debug" : debug ,
152157 "log_file" : log_file ,
153- "wrapper" : wrapper ,
154- "reward_name" : reward_name ,
158+ "reward_fn" : reward_fn ,
155159 ** kwargs
156160 }
157161
0 commit comments