1111from trinity .common .models .model import ModelWrapper
1212from trinity .common .workflows .workflow import WORKFLOWS , Task , Workflow
1313
14- from .templates import TEMPLATE_MAP
15-
1614
1715@WORKFLOWS .register_module ("as_react_workflow" )
1816class AgentScopeReActWorkflow (Workflow ):
@@ -29,8 +27,13 @@ def __init__(
2927 auxiliary_models = auxiliary_models ,
3028 )
3129 self .model_client = model .get_openai_async_client ()
30+ self .reset (task )
31+
32+ def reset (self , task : Task ):
33+ from trinity .common .workflows .agentscope .react .templates import TEMPLATE_MAP
3234
3335 task_type = task .workflow_args .get ("type" , "gsm8k" )
36+ self .logger .info (f"task_type: { task_type } " )
3437 template = TEMPLATE_MAP .get (task_type , None )
3538 if template is None :
3639 raise ValueError (
@@ -40,6 +43,13 @@ def __init__(
4043 self .query = task .raw_task .get (task .format_args .prompt_key ) # type: ignore [index]
4144 self .answer = task .raw_task .get (task .format_args .response_key ) # type: ignore [index]
4245 self .reward_fn = template .reward_fn_cls (** task .reward_fn_args )
46+ self .toolkit_manager = template .toolkit_manager (task = task )
47+
48+ system_prompt = (
49+ template .system_prompt
50+ if isinstance (template .system_prompt , str )
51+ else template .system_prompt (task )
52+ )
4353
4454 # import here to avoid the import error if agentscope is not installed and this workflow is not used
4555 try :
@@ -53,32 +63,43 @@ def __init__(
5363 self .agent = AgentScopeReActAgent (
5464 model_name = self .model_client .model_path ,
5565 openai_client = self .model_client ,
56- system_prompt = template . system_prompt ,
66+ system_prompt = system_prompt ,
5767 generate_kwargs = {
5868 "temperature" : self .rollout_args .get ("temperature" , 1.0 ),
5969 "max_tokens" : self .rollout_args .get ("max_tokens" , 4096 ),
6070 },
6171 response_structure = template .response_structure ,
72+ toolkit = self .toolkit_manager .toolkit ,
6273 )
6374
async def run_async(self):
    """Run the workflow asynchronously.

    The ReAct agent answers ``self.query``; the experiences recorded by
    ``self.model`` are then extracted, the response is scored, and the
    experiences are handed to ``construct_experiences`` together with the
    reward.

    Returns:
        The value returned by ``self.construct_experiences(reward, exps)``.
    """
    # Step 1: call the react agent to solve the task
    response = await self.agent.reply(self.query)
    # Step 2: extract the experiences from the model's interaction history;
    # done before scoring because the reward also receives them
    exps = self.model.extract_experience_from_history()
    # Step 3: calculate the reward based on the response
    reward = await self.calculate_reward(response, exps)
    # Step 4: construct experiences from the interaction history and return them
    return self.construct_experiences(reward, exps)
7285
async def calculate_reward(self, response, exps) -> Union[float, Dict[str, float]]:
    """Calculate the reward for the workflow.

    Args:
        response: The agent's reply, forwarded to the reward function.
        exps: Experiences extracted from the interaction history; only
            their count is used here (passed on as ``num_turns``).

    Returns:
        Union[float, Dict[str, float]]: The reward value or a dictionary of reward value.
    """
    return self.reward_fn(
        response=response,
        truth=self.answer,
        auxiliary_models=self.auxiliary_models,
        num_turns=len(exps),
        # whatever the toolkit manager reports is forwarded as extra keyword
        # arguments; its keys must not collide with the names above
        **self.toolkit_manager.get_status(),
    )
8099
81- def construct_experiences (self , reward : Union [float , Dict [str , float ]]) -> List [Experience ]:
100+ def construct_experiences (
101+ self , reward : Union [float , Dict [str , float ]], exps
102+ ) -> List [Experience ]:
82103 """Construct experiences from the agent's interaction history.
83104
84105 Args:
@@ -87,10 +108,16 @@ def construct_experiences(self, reward: Union[float, Dict[str, float]]) -> List[
87108 Returns:
88109 List: A list of Experience objects.
89110 """
90- exps = self .model .extract_experience_from_history ()
91- for exp in exps :
92- exp .reward = reward if isinstance (reward , float ) else sum (reward .values ())
93- exp .metrics = {"react_memory_length" : len (self .agent .agent .memory .content )}
111+ reward_value = reward if isinstance (reward , float ) else sum (reward .values ())
112+ react_memory_length = len (self .agent .agent .memory .content )
113+ num_turns = len (exps )
114+ for i , exp in enumerate (exps ):
115+ exp .eid .step = i
116+ exp .reward = reward_value
117+ if exp .metrics is None :
118+ exp .metrics = {}
119+ exp .metrics ["react_memory_length" ] = react_memory_length
120+ exp .metrics ["num_turns" ] = num_turns
94121 # record detailed reward if available
95122 if isinstance (reward , dict ):
96123 exp .metrics .update (reward )
@@ -105,3 +132,7 @@ def asynchronous(self):
105132 def repeatable (self ):
106133 """This workflow is not repeatable."""
107134 return False
135+
@property
def resettable(self):
    """This workflow is resettable."""
    return True
0 commit comments