
Commit 3b53c83

refactor AgentScopeReActWorkflow

1 parent 0e1f545 · commit 3b53c83

File tree

3 files changed: +328 -18 lines changed

examples/grpo_email_search/email_search.yaml

Lines changed: 3 additions & 1 deletion
```diff
@@ -26,6 +26,7 @@ buffer:
     prompt_key: 'question'
     response_key: 'answer'
     workflow_args:
+      type: email_search
       max_turns: 10
     reward_fn_args:
       llm_as_a_judge: true
@@ -42,13 +43,14 @@ buffer:
     response_key: 'answer'
     enable_progress_bar: false
     workflow_args:
+      type: email_search
       max_turns: 10
     reward_fn_args:
       llm_as_a_judge: true
     rollout_args:
       temperature: 0.6
       # max_tokens: 4096
-    default_workflow_type: 'email_search_workflow'
+    default_workflow_type: 'as_react_workflow'
   trainer_input:
     experience_buffer:
       name: experience_buffer
```
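Net effect on the config: each workflow_args block now carries an explicit task type, and the default workflow switches from the task-specific email_search_workflow to the generic as_react_workflow, which dispatches on that type. A sketch of the relevant keys after the change (surrounding keys elided):

```yaml
workflow_args:
  type: email_search            # selects the task template from TEMPLATE_MAP
  max_turns: 10
default_workflow_type: 'as_react_workflow'   # was 'email_search_workflow'
```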

trinity/common/workflows/agentscope/react/react_workflow.py

Lines changed: 45 additions & 14 deletions
```diff
@@ -11,8 +11,6 @@
 from trinity.common.models.model import ModelWrapper
 from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
 
-from .templates import TEMPLATE_MAP
-
 
 @WORKFLOWS.register_module("as_react_workflow")
 class AgentScopeReActWorkflow(Workflow):
```
```diff
@@ -29,8 +27,13 @@ def __init__(
             auxiliary_models=auxiliary_models,
         )
         self.model_client = model.get_openai_async_client()
+        self.reset(task)
+
+    def reset(self, task: Task):
+        from trinity.common.workflows.agentscope.react.templates import TEMPLATE_MAP
 
         task_type = task.workflow_args.get("type", "gsm8k")
+        self.logger.info(f"task_type: {task_type}")
         template = TEMPLATE_MAP.get(task_type, None)
         if template is None:
             raise ValueError(
```
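For orientation: TEMPLATE_MAP maps the `type` value from workflow_args to a per-task-type template object defined in templates.py (not shown in this section). A hypothetical shape for such an entry, inferred only from the attributes this diff reads off `template` (the real class may differ):

```python
from dataclasses import dataclass
from typing import Any, Callable, Union

# Hypothetical sketch; it bundles everything reset() needs for one task type.
@dataclass
class ReActTemplate:
    system_prompt: Union[str, Callable]  # plain string, or callable(task) -> str
    reward_fn_cls: type                  # instantiated with **task.reward_fn_args
    response_structure: Any              # forwarded to AgentScopeReActAgent
    toolkit_manager: Callable            # called with task=...; exposes .toolkit and .get_status()

TEMPLATE_MAP: dict = {}  # populated in templates.py, e.g. "gsm8k", "email_search"
```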
```diff
@@ -40,6 +43,13 @@ def __init__(
         self.query = task.raw_task.get(task.format_args.prompt_key)  # type: ignore [index]
         self.answer = task.raw_task.get(task.format_args.response_key)  # type: ignore [index]
         self.reward_fn = template.reward_fn_cls(**task.reward_fn_args)
+        self.toolkit_manager = template.toolkit_manager(task=task)
+
+        system_prompt = (
+            template.system_prompt
+            if isinstance(template.system_prompt, str)
+            else template.system_prompt(task)
+        )
 
         # import here to avoid the import error if agentscope is not installed and this workflow is not used
         try:
```
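The conditional above means a template may declare its system prompt either as a static string or as a function of the task, e.g. to inject per-task context. An illustration (the function name and prompt text are made up, not taken from the repo):

```python
# Illustrative only: a callable prompt specializes the system message per
# task, while a string prompt keeps the previous static behavior.
def build_email_search_prompt(task) -> str:
    question = task.raw_task.get(task.format_args.prompt_key)
    return f"You are an email-search agent. Use your tools to answer: {question}"

static_prompt = "You are a careful math assistant."  # also accepted unchanged
```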
```diff
@@ -53,32 +63,43 @@
         self.agent = AgentScopeReActAgent(
             model_name=self.model_client.model_path,
             openai_client=self.model_client,
-            system_prompt=template.system_prompt,
+            system_prompt=system_prompt,
             generate_kwargs={
                 "temperature": self.rollout_args.get("temperature", 1.0),
                 "max_tokens": self.rollout_args.get("max_tokens", 4096),
             },
             response_structure=template.response_structure,
+            toolkit=self.toolkit_manager.toolkit,
         )
 
     async def run_async(self):
         """Run the workflow asynchronously."""
         # Step 1: call the react agent to solve the task
         response = await self.agent.reply(self.query)
-        # Step 2: calculate the reward based on the response
-        reward = await self.calculate_reward(response)
-        # Step 3: construct experiences from the interaction history and return them
-        return self.construct_experiences(reward)
+        # Step 2: extract the experience
+        exps = self.model.extract_experience_from_history()
+        # Step 3: calculate the reward based on the response
+        reward = await self.calculate_reward(response, exps)
+        # Step 4: construct experiences from the interaction history and return them
+        return self.construct_experiences(reward, exps)
 
-    async def calculate_reward(self, response) -> Union[float, Dict[str, float]]:
+    async def calculate_reward(self, response, exps) -> Union[float, Dict[str, float]]:
         """Calculate the reward for the workflow.
 
         Returns:
             Union[float, Dict[str, float]]: The reward value or a dictionary of reward value.
         """
-        return self.reward_fn(response=response, truth=self.answer)
+        return self.reward_fn(
+            response=response,
+            truth=self.answer,
+            auxiliary_models=self.auxiliary_models,
+            num_turns=len(exps),
+            **self.toolkit_manager.get_status(),
+        )
 
-    def construct_experiences(self, reward: Union[float, Dict[str, float]]) -> List[Experience]:
+    def construct_experiences(
+        self, reward: Union[float, Dict[str, float]], exps
+    ) -> List[Experience]:
         """Construct experiences from the agent's interaction history.
 
         Args:
```
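Because calculate_reward now forwards auxiliary_models, num_turns, and whatever toolkit_manager.get_status() returns, a template's reward_fn_cls must accept those keywords even if it ignores some of them. A hypothetical compatible class (name and scoring logic are illustrative, not the repo's actual reward):

```python
from typing import Dict, Union

class JudgedReward:
    """Illustrative reward callable matching the new call site."""

    def __init__(self, llm_as_a_judge: bool = False):
        self.llm_as_a_judge = llm_as_a_judge

    def __call__(
        self, response, truth, auxiliary_models=None, num_turns=0, **toolkit_status
    ) -> Union[float, Dict[str, float]]:
        # Placeholder scoring; a real implementation might call an auxiliary
        # judge model when llm_as_a_judge is set.
        accuracy = 1.0 if truth is not None and str(truth) in str(response) else 0.0
        return {"accuracy": accuracy, "turn_penalty": -0.01 * num_turns}
```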
```diff
@@ -87,10 +108,16 @@ def construct_experiences(self, reward: Union[float, Dict[str, float]]) -> List[
         Returns:
             List: A list of Experience objects.
         """
-        exps = self.model.extract_experience_from_history()
-        for exp in exps:
-            exp.reward = reward if isinstance(reward, float) else sum(reward.values())
-            exp.metrics = {"react_memory_length": len(self.agent.agent.memory.content)}
+        reward_value = reward if isinstance(reward, float) else sum(reward.values())
+        react_memory_length = len(self.agent.agent.memory.content)
+        num_turns = len(exps)
+        for i, exp in enumerate(exps):
+            exp.eid.step = i
+            exp.reward = reward_value
+            if exp.metrics is None:
+                exp.metrics = {}
+            exp.metrics["react_memory_length"] = react_memory_length
+            exp.metrics["num_turns"] = num_turns
             # record detailed reward if available
             if isinstance(reward, dict):
                 exp.metrics.update(reward)
```
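The scalar written to exp.reward is the sum of a dict reward's components, while the components themselves are recorded per-key in exp.metrics. A tiny worked example of the aggregation line:

```python
# Worked example of the aggregation above.
reward = {"accuracy": 1.0, "turn_penalty": -0.25}
reward_value = reward if isinstance(reward, float) else sum(reward.values())
assert reward_value == 0.75  # stored on every Experience in this rollout
```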
```diff
@@ -105,3 +132,7 @@ def asynchronous(self):
     def repeatable(self):
         """This workflow is not repeatable."""
         return False
+
+    @property
+    def resettable(self):
+        return True
```
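The new resettable property advertises that all per-task state now lives in reset(), so a caller can swap tasks without rebuilding the workflow. A minimal sketch of how a driver might use it (run_sequentially is hypothetical, not Trinity's actual scheduler):

```python
# Hypothetical driver: reuse one workflow instance across tasks by calling
# reset() between rollouts instead of constructing a new workflow per task.
async def run_sequentially(workflow, tasks):
    assert workflow.resettable  # advertised by the property above
    experiences = []
    for task in tasks:
        workflow.reset(task)  # re-binds query, answer, toolkit, and agent
        experiences.extend(await workflow.run_async())
    return experiences
```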
