diff --git a/docs/sphinx_doc/assets/agentscope_gsm8k_reward.png b/docs/sphinx_doc/assets/agentscope_gsm8k_reward.png
index 61ff756e8a..867842c2c4 100644
Binary files a/docs/sphinx_doc/assets/agentscope_gsm8k_reward.png and b/docs/sphinx_doc/assets/agentscope_gsm8k_reward.png differ
diff --git a/docs/sphinx_doc/assets/email_eval_accuracy.png b/docs/sphinx_doc/assets/email_eval_accuracy.png
index 215110084b..311d4961a1 100644
Binary files a/docs/sphinx_doc/assets/email_eval_accuracy.png and b/docs/sphinx_doc/assets/email_eval_accuracy.png differ
diff --git a/docs/sphinx_doc/assets/email_reward_mean.png b/docs/sphinx_doc/assets/email_reward_mean.png
new file mode 100644
index 0000000000..45c102591e
Binary files /dev/null and b/docs/sphinx_doc/assets/email_reward_mean.png differ
diff --git a/docs/sphinx_doc/assets/email_rollout_accuracy.png b/docs/sphinx_doc/assets/email_rollout_accuracy.png
index ce7848d081..729d97aae7 100644
Binary files a/docs/sphinx_doc/assets/email_rollout_accuracy.png and b/docs/sphinx_doc/assets/email_rollout_accuracy.png differ
diff --git a/docs/sphinx_doc/source/tutorial/example_search_email.md b/docs/sphinx_doc/source/tutorial/example_search_email.md
index e35e7b08bc..6be9e71320 100644
--- a/docs/sphinx_doc/source/tutorial/example_search_email.md
+++ b/docs/sphinx_doc/source/tutorial/example_search_email.md
@@ -48,5 +48,6 @@ The results are shown in the following figure (the accuracy ranges from -0.1 to
 
 ![](../../assets/email_rollout_accuracy.png)
 
+![](../../assets/email_reward_mean.png)
 ![](../../assets/email_eval_accuracy.png)
 
diff --git a/docs/sphinx_doc/source_zh/tutorial/example_search_email.md b/docs/sphinx_doc/source_zh/tutorial/example_search_email.md
index 086c21f7d4..79660657ab 100644
--- a/docs/sphinx_doc/source_zh/tutorial/example_search_email.md
+++ b/docs/sphinx_doc/source_zh/tutorial/example_search_email.md
@@ -44,4 +44,6 @@ trinity run --config examples/grpo_email_search/email_search.yaml
 
 ![](../../assets/email_rollout_accuracy.png)
 
+![](../../assets/email_reward_mean.png)
+
 ![](../../assets/email_eval_accuracy.png)
diff --git a/examples/grpo_email_search/email_search.yaml b/examples/grpo_email_search/email_search.yaml
index c3227b1456..24a58a8c96 100644
--- a/examples/grpo_email_search/email_search.yaml
+++ b/examples/grpo_email_search/email_search.yaml
@@ -6,6 +6,20 @@ algorithm:
   repeat_times: 8
   optimizer:
     lr: 1e-6
+  policy_loss_fn: "rec"
+  policy_loss_fn_args:
+    epsilon_low: 0.2
+    epsilon_high: 0.2
+    clip_mode: "one-side"
+    weight: "none"
+    temp: 1.0
+    regularizer: "none"
+    regularizer_coef: 0.0
+  kl_loss_fn: 'k2'
+  kl_loss_fn_args:
+    kl_coef: 0.0
+  advantage_fn_args:
+    std_cal_level: 'batch'
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}
   max_response_tokens: 4096
@@ -15,8 +29,8 @@ cluster:
   gpu_per_node: 8
 buffer:
   total_epochs: 1
-  batch_size: 16
-  train_batch_size: 640 # 16*8*5
+  batch_size: 64
+  train_batch_size: 2560 # 64*8*5
   explorer_input:
     taskset:
       name: enron_train
@@ -56,6 +70,12 @@ buffer:
       storage_type: queue
      replay_buffer:
        enable: true
+        # reuse_cooldown_time is None
+        priority_fn: 'decay_limit_randomization'
+        priority_fn_args:
+          decay: 2.0
+          use_count_limit: 3
+          sigma: 2.0
 explorer:
   eval_interval: 10
   max_repeat_times_per_runner: 1
@@ -93,3 +113,5 @@ trainer:
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
+monitor:
+  monitor_type: wandb
diff --git a/trinity/common/workflows/agentscope/react/react_agent.py b/trinity/common/workflows/agentscope/react/react_agent.py
index df5c2be3cf..0cf8babe22 100644
--- a/trinity/common/workflows/agentscope/react/react_agent.py
+++ b/trinity/common/workflows/agentscope/react/react_agent.py
@@ -17,6 +17,7 @@ def __init__(
         system_prompt: str,
         generate_kwargs: dict,
         response_structure: Type[BaseModel],
+        max_iters: int = 10,
         toolkit: Toolkit | None = None,
     ):
         """Initialize the AgentScope ReAct agent with specified tools and model.
@@ -44,6 +45,7 @@ def __init__(
             # we enable agentscope's meta tool to allow agent to call tools dynamically without pre-registration
             enable_meta_tool=True,
             toolkit=toolkit,
+            max_iters=max_iters,
         )
 
         self.response_structure = response_structure
diff --git a/trinity/common/workflows/envs/email_searcher/utils.py b/trinity/common/workflows/envs/email_searcher/utils.py
index 6583c613b7..daf88861db 100644
--- a/trinity/common/workflows/envs/email_searcher/utils.py
+++ b/trinity/common/workflows/envs/email_searcher/utils.py
@@ -288,7 +288,7 @@ def read_email_tool(message_id: str) -> Optional[Email]:
 
 
 ############ LLM-as-a-judge ############
-def judge_correctness(
+async def judge_correctness(
     answer: str,
     query: QueryModel,
     judger: Any,
@@ -318,7 +318,7 @@
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": prompt},
     ]
-    completion = judger.chat.completions.create(
+    completion = await judger.chat.completions.create(
         model=judger.model_path, messages=messages, stream=False
     )
     result = completion.choices[0].message.content
diff --git a/trinity/common/workflows/envs/email_searcher/workflow.py b/trinity/common/workflows/envs/email_searcher/workflow.py
index 737f9fc279..0df651c72d 100644
--- a/trinity/common/workflows/envs/email_searcher/workflow.py
+++ b/trinity/common/workflows/envs/email_searcher/workflow.py
@@ -82,6 +82,7 @@ def reset(self, task: Task):
                 "max_tokens": self.task.rollout_args.max_tokens or 4096,
             },
             response_structure=AnswerModel,
+            max_iters=self.max_turns,
        )
 
     async def run_async(self):
@@ -92,7 +93,7 @@ async def run_async(self):
             experiences
         )
         # NOTE: this metrics works only if the agent calls model once in each turn
-        reward_dict = self.calculate_reward(answer_and_sources)
+        reward_dict = await self.calculate_reward(answer_and_sources)
         reward = sum(reward_dict.values())
 
         for i, experience in enumerate(experiences):
@@ -107,7 +108,7 @@
             )
         return experiences
 
-    def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
+    async def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
         """Ref: calculate_reward in https://github.com/OpenPipe/ART/blob/main/dev/art-e/art_e/rollout.py#L64"""
         try:
             answer = answer_and_sources.get("answer", None)
@@ -140,7 +141,7 @@
 
         try:
             judge_model = self.auxiliary_models[0] if self.auxiliary_models else None
-            judge_response = judge_correctness(answer, self.query, judge_model)
+            judge_response = await judge_correctness(answer, self.query, judge_model)
             rubric.answer_correct = judge_response
 
         except Exception as e:
@@ -179,4 +180,4 @@
             return result
 
         self.logger.error(f"Rubric {rubric} not handled properly")
-        raise ValueError("Rubric is not handled properly")
+        return {"accuracy": 0.0, "format": 0.0}