Commit fb67735

Fix/email search (#351)
1 parent 6ff5195 commit fb67735

File tree: 10 files changed, +36 −8 lines

Four binary image files changed (size deltas as shown: −147 KB, −15.8 KB, 464 KB, −50.4 KB); image previews not reproduced here.

docs/sphinx_doc/source/tutorial/example_search_email.md

Lines changed: 1 addition & 0 deletions

@@ -48,5 +48,6 @@ The results are shown in the following figure (the accuracy ranges from -0.1 to
 
 ![](../../assets/email_rollout_accuracy.png)
 
+![](../../assets/email_reward_mean.png)
 
 ![](../../assets/email_eval_accuracy.png)

docs/sphinx_doc/source_zh/tutorial/example_search_email.md

Lines changed: 2 additions & 0 deletions

@@ -44,4 +44,6 @@ trinity run --config examples/grpo_email_search/email_search.yaml
 
 ![](../../assets/email_rollout_accuracy.png)
 
+![](../../assets/email_reward_mean.png)
+
 ![](../../assets/email_eval_accuracy.png)

examples/grpo_email_search/email_search.yaml

Lines changed: 24 additions & 2 deletions

@@ -6,6 +6,20 @@ algorithm:
   repeat_times: 8
   optimizer:
     lr: 1e-6
+  policy_loss_fn: "rec"
+  policy_loss_fn_args:
+    epsilon_low: 0.2
+    epsilon_high: 0.2
+    clip_mode: "one-side"
+    weight: "none"
+    temp: 1.0
+    regularizer: "none"
+    regularizer_coef: 0.0
+  kl_loss_fn: 'k2'
+  kl_loss_fn_args:
+    kl_coef: 0.0
+  advantage_fn_args:
+    std_cal_level: 'batch'
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}
   max_response_tokens: 4096
@@ -15,8 +29,8 @@ cluster:
   gpu_per_node: 8
 buffer:
   total_epochs: 1
-  batch_size: 16
-  train_batch_size: 640 # 16*8*5
+  batch_size: 64
+  train_batch_size: 2560 # 64*8*5
   explorer_input:
     taskset:
       name: enron_train
@@ -56,6 +70,12 @@ buffer:
       storage_type: queue
     replay_buffer:
       enable: true
+      # reuse_cooldown_time is None
+      priority_fn: 'decay_limit_randomization'
+      priority_fn_args:
+        decay: 2.0
+        use_count_limit: 3
+        sigma: 2.0
 explorer:
   eval_interval: 10
   max_repeat_times_per_runner: 1
@@ -93,3 +113,5 @@ trainer:
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
+monitor:
+  monitor_type: wandb
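Two things worth noting in this config. First, the training batch scales in lockstep with the rollout batch: per the inline comment, 2560 = 64 × 8 × 5 (batch_size × repeat_times × an extra factor not shown in this hunk). Second, the replay buffer now names a 'decay_limit_randomization' priority function with decay, use_count_limit, and sigma arguments. As a rough illustration only — this is not Trinity-RFT's actual implementation — such a priority might decay with each reuse, retire items past a reuse cap, and add Gaussian jitter:

```python
import random

def decay_limit_randomization(use_count: int, decay: float = 2.0,
                              use_count_limit: int = 3, sigma: float = 2.0) -> float:
    """Hypothetical sketch of a reuse-aware sampling priority."""
    if use_count >= use_count_limit:
        return float("-inf")                 # retire items reused too often
    base = decay ** (-use_count)             # geometric decay per reuse
    return base + random.gauss(0.0, sigma)   # jitter to diversify sampling

# Toy usage: fresh items tend to outrank heavily reused ones.
print(sorted(range(4), key=decay_limit_randomization, reverse=True))
```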

trinity/common/workflows/agentscope/react/react_agent.py

Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,7 @@ def __init__(
         system_prompt: str,
         generate_kwargs: dict,
         response_structure: Type[BaseModel],
+        max_iters: int = 10,
         toolkit: Toolkit | None = None,
     ):
         """Initialize the AgentScope ReAct agent with specified tools and model.
@@ -44,6 +45,7 @@ def __init__(
             # we enable agentscope's meta tool to allow agent to call tools dynamically without pre-registration
             enable_meta_tool=True,
             toolkit=toolkit,
+            max_iters=max_iters,
         )
         self.response_structure = response_structure
trinity/common/workflows/envs/email_searcher/utils.py

Lines changed: 2 additions & 2 deletions

@@ -288,7 +288,7 @@ def read_email_tool(message_id: str) -> Optional[Email]:
 ############ LLM-as-a-judge ############
 
 
-def judge_correctness(
+async def judge_correctness(
     answer: str,
     query: QueryModel,
     judger: Any,
@@ -318,7 +318,7 @@ def judge_correctness(
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": prompt},
     ]
-    completion = judger.chat.completions.create(
+    completion = await judger.chat.completions.create(
        model=judger.model_path, messages=messages, stream=False
     )
     result = completion.choices[0].message.content
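Making judge_correctness a coroutine means every caller must now await it; a caller that forgets gets back a coroutine object instead of a verdict, which is exactly what the companion edits in workflow.py below fix. A self-contained sketch of the pattern (the stub stands in for the real judge client):

```python
import asyncio

async def judge_correctness_stub(answer: str) -> bool:
    # Stand-in for the real judge, which awaits an async
    # chat.completions.create call on the judger client.
    await asyncio.sleep(0)  # simulate non-blocking I/O
    return bool(answer.strip())

async def main():
    verdict = await judge_correctness_stub("alice@enron.com")
    print(verdict)  # True -- without `await`, this would be a coroutine object

asyncio.run(main())
```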

trinity/common/workflows/envs/email_searcher/workflow.py

Lines changed: 5 additions & 4 deletions

@@ -82,6 +82,7 @@ def reset(self, task: Task):
                 "max_tokens": self.task.rollout_args.max_tokens or 4096,
             },
             response_structure=AnswerModel,
+            max_iters=self.max_turns,
         )
 
     async def run_async(self):
@@ -92,7 +93,7 @@ async def run_async(self):
             experiences
         )  # NOTE: this metrics works only if the agent calls model once in each turn
 
-        reward_dict = self.calculate_reward(answer_and_sources)
+        reward_dict = await self.calculate_reward(answer_and_sources)
         reward = sum(reward_dict.values())
 
         for i, experience in enumerate(experiences):
@@ -107,7 +108,7 @@ async def run_async(self):
             )
         return experiences
 
-    def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
+    async def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
         """Ref: calculate_reward in https://github.com/OpenPipe/ART/blob/main/dev/art-e/art_e/rollout.py#L64"""
         try:
             answer = answer_and_sources.get("answer", None)
@@ -140,7 +141,7 @@ def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
 
         try:
             judge_model = self.auxiliary_models[0] if self.auxiliary_models else None
-            judge_response = judge_correctness(answer, self.query, judge_model)
+            judge_response = await judge_correctness(answer, self.query, judge_model)
             rubric.answer_correct = judge_response
 
         except Exception as e:
@@ -179,4 +180,4 @@ def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
             return result
 
         self.logger.error(f"Rubric {rubric} not handled properly")
-        raise ValueError("Rubric is not handled properly")
+        return {"accuracy": 0.0, "format": 0.0}
