Binary file modified docs/sphinx_doc/assets/agentscope_gsm8k_reward.png
Binary file modified docs/sphinx_doc/assets/email_eval_accuracy.png
Binary file added docs/sphinx_doc/assets/email_reward_mean.png
Binary file modified docs/sphinx_doc/assets/email_rollout_accuracy.png
1 change: 1 addition & 0 deletions docs/sphinx_doc/source/tutorial/example_search_email.md
@@ -48,5 +48,6 @@ The results are shown in the following figure (the accuracy ranges from -0.1 to
 
 ![](../../assets/email_rollout_accuracy.png)
 
+![](../../assets/email_reward_mean.png)
 
 ![](../../assets/email_eval_accuracy.png)
2 changes: 2 additions & 0 deletions docs/sphinx_doc/source_zh/tutorial/example_search_email.md
@@ -44,4 +44,6 @@ trinity run --config examples/grpo_email_search/email_search.yaml
 
 ![](../../assets/email_rollout_accuracy.png)
 
+![](../../assets/email_reward_mean.png)
+
 ![](../../assets/email_eval_accuracy.png)
26 changes: 24 additions & 2 deletions examples/grpo_email_search/email_search.yaml
@@ -6,6 +6,20 @@ algorithm:
   repeat_times: 8
   optimizer:
     lr: 1e-6
+  policy_loss_fn: "rec"
+  policy_loss_fn_args:
+    epsilon_low: 0.2
+    epsilon_high: 0.2
+    clip_mode: "one-side"
+    weight: "none"
+    temp: 1.0
+    regularizer: "none"
+    regularizer_coef: 0.0
+  kl_loss_fn: 'k2'
+  kl_loss_fn_args:
+    kl_coef: 0.0
+  advantage_fn_args:
+    std_cal_level: 'batch'
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}
   max_response_tokens: 4096
@@ -15,8 +29,8 @@ cluster:
   gpu_per_node: 8
 buffer:
   total_epochs: 1
-  batch_size: 16
-  train_batch_size: 640 # 16*8*5
+  batch_size: 64
+  train_batch_size: 2560 # 64*8*5
   explorer_input:
     taskset:
       name: enron_train
@@ -56,6 +70,12 @@ buffer:
       storage_type: queue
     replay_buffer:
       enable: true
+      # reuse_cooldown_time is None
+      priority_fn: 'decay_limit_randomization'
+      priority_fn_args:
+        decay: 2.0
+        use_count_limit: 3
+        sigma: 2.0
 explorer:
   eval_interval: 10
   max_repeat_times_per_runner: 1
@@ -93,3 +113,5 @@ trainer:
     use_dynamic_bsz: true
     max_token_len_per_gpu: 16384
     ulysses_sequence_parallel_size: 1
+monitor:
+  monitor_type: wandb
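
The batch-size bump is coupled to the GRPO settings above: per the inline `# 64*8*5` comment, `train_batch_size` is `batch_size × repeat_times × 5`. A minimal sketch of that arithmetic, assuming the factor of 5 is the number of explorer batches accumulated per trainer step (the config itself does not name it):

```python
# Sanity check of the train_batch_size relationship implied by the
# "# 64*8*5" comment. The meaning of the factor 5 is an assumption here
# (e.g. explorer batches accumulated per trainer step); the config does
# not spell it out.
batch_size = 64        # tasks per explorer step (buffer.batch_size)
repeat_times = 8       # GRPO rollouts per task (algorithm.repeat_times)
accumulated_steps = 5  # assumed multiplier from the inline comment

train_batch_size = batch_size * repeat_times * accumulated_steps
assert train_batch_size == 2560
```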
2 changes: 2 additions & 0 deletions trinity/common/workflows/agentscope/react/react_agent.py
@@ -17,6 +17,7 @@ def __init__(
         system_prompt: str,
         generate_kwargs: dict,
         response_structure: Type[BaseModel],
+        max_iters: int = 10,
         toolkit: Toolkit | None = None,
     ):
         """Initialize the AgentScope ReAct agent with specified tools and model.
@@ -44,6 +45,7 @@ def __init__(
             # we enable agentscope's meta tool to allow agent to call tools dynamically without pre-registration
             enable_meta_tool=True,
             toolkit=toolkit,
+            max_iters=max_iters,
         )
         self.response_structure = response_structure
 
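
The new `max_iters` parameter is forwarded to the underlying AgentScope `ReActAgent`, capping how many reason/act turns the agent may take before it must answer. A hypothetical construction sketch (the wrapper's class name and the other argument values are illustrative, not taken from the diff):

```python
# Hypothetical usage; only the max_iters parameter is confirmed by the diff above.
from pydantic import BaseModel

class Answer(BaseModel):  # stand-in response structure
    answer: str

agent = ReActAgentWrapper(        # assumed name of this module's agent class
    system_prompt="Answer by searching the inbox.",
    generate_kwargs={"temperature": 1.0, "max_tokens": 4096},
    response_structure=Answer,
    max_iters=5,                  # new knob: hard cap on ReAct iterations
)
```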
4 changes: 2 additions & 2 deletions trinity/common/workflows/envs/email_searcher/utils.py
@@ -288,7 +288,7 @@ def read_email_tool(message_id: str) -> Optional[Email]:
 ############ LLM-as-a-judge ############
 
 
-def judge_correctness(
+async def judge_correctness(
     answer: str,
     query: QueryModel,
     judger: Any,
@@ -318,7 +318,7 @@ def judge_correctness(
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": prompt},
     ]
-    completion = judger.chat.completions.create(
+    completion = await judger.chat.completions.create(
         model=judger.model_path, messages=messages, stream=False
     )
     result = completion.choices[0].message.content
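
Making `judge_correctness` a coroutine means every caller must now `await` it, and `judger` is expected to expose an awaitable `chat.completions.create` (an AsyncOpenAI-style client). A minimal sketch of the updated call pattern, with the surrounding names assumed to come from the email-search workflow:

```python
# Sketch only: `answer`, `query`, and `judger` are assumed to be supplied by
# the workflow; judge_correctness is the coroutine changed above.
async def check(answer, query, judger):
    # Calling without await would now return a coroutine, not a verdict.
    verdict = await judge_correctness(answer, query, judger)
    return verdict
```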
9 changes: 5 additions & 4 deletions trinity/common/workflows/envs/email_searcher/workflow.py
@@ -82,6 +82,7 @@ def reset(self, task: Task):
                 "max_tokens": self.task.rollout_args.max_tokens or 4096,
             },
             response_structure=AnswerModel,
+            max_iters=self.max_turns,
         )
 
     async def run_async(self):
@@ -92,7 +93,7 @@ async def run_async(self):
             experiences
         ) # NOTE: this metrics works only if the agent calls model once in each turn
 
-        reward_dict = self.calculate_reward(answer_and_sources)
+        reward_dict = await self.calculate_reward(answer_and_sources)
         reward = sum(reward_dict.values())
 
         for i, experience in enumerate(experiences):
@@ -107,7 +108,7 @@
             )
             return experiences
 
-    def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
+    async def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
         """Ref: calculate_reward in https://github.com/OpenPipe/ART/blob/main/dev/art-e/art_e/rollout.py#L64"""
         try:
             answer = answer_and_sources.get("answer", None)
@@ -140,7 +141,7 @@ def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
 
         try:
             judge_model = self.auxiliary_models[0] if self.auxiliary_models else None
-            judge_response = judge_correctness(answer, self.query, judge_model)
+            judge_response = await judge_correctness(answer, self.query, judge_model)
             rubric.answer_correct = judge_response
 
         except Exception as e:
@@ -179,4 +180,4 @@ def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
             return result
 
         self.logger.error(f"Rubric {rubric} not handled properly")
-        raise ValueError("Rubric is not handled properly")
+        return {"accuracy": 0.0, "format": 0.0}
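
Two behaviors change together here: the reward path is now fully asynchronous (`calculate_reward` awaits the LLM judge), and an unhandled rubric degrades to a zero reward instead of raising `ValueError`, so a bad rollout no longer aborts the batch. A minimal sketch of the awaited path, assuming `workflow` is an instance of this workflow class:

```python
# Sketch, assuming `workflow` is the email-search workflow changed above.
async def score(workflow, answer_and_sources) -> float:
    reward_dict = await workflow.calculate_reward(answer_and_sources)
    # On an unhandled rubric this is now {"accuracy": 0.0, "format": 0.0},
    # so the total reward is 0.0 rather than a raised ValueError.
    return sum(reward_dict.values())
```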