
Commit 894858b

Non-verifiable Medicine QA Task (agentscope-ai#317)
1 parent 51c4287 commit 894858b

File tree: 7 files changed, +287 −1 lines changed
docs/sphinx_doc/assets/grpo_rubric_reward.png — 367 KB binary image (preview not shown)
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
# Non-Verifiable Medicine QA

This example shows how to use an LLM judge and rubrics to compute rewards for a non-verifiable medicine QA task. It is inspired by the [RaR-Implicit](https://arxiv.org/pdf/2507.17746) method.

Before running this example, please make sure you have prepared the environment and the dataset [anisha2102/RaR-Medicine](https://huggingface.co/datasets/anisha2102/RaR-Medicine).
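If you have not downloaded the dataset yet, the snippet below is a minimal sketch using the Hugging Face `datasets` library (assumed to be installed; it is not part of this example's files):

```python
# Minimal sketch: fetch and inspect the RaR-Medicine dataset.
# Assumes the Hugging Face `datasets` package is available in your environment.
from datasets import load_dataset

ds = load_dataset("anisha2102/RaR-Medicine", split="train")
print(len(ds))                 # roughly 20k QA pairs
sample = ds[0]
print(sample["question"])      # the medical question
print(sample["rubric"][0])     # each rubric entry has a title, description, and weight
```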
The RaR-Medicine dataset contains around 20k QA pairs with rubrics in the medicine domain. Unlike math scenarios, it is infeasible to obtain verifiable rewards for this dataset. Below is an example data sample:

```json
{
    "question": "What is the most sensitive imaging modality for diagnosing a ureteric stone in a patient presenting with acute renal colic?",
    "reference_answer": "The most sensitive imaging modality for diagnosing a ureteric stone in a patient presenting with acute renal colic is a non-contrast helical CT scan. This method is highly accurate, able to detect stones of varying sizes and compositions, and preferred due to its quick and reliable results without the need for contrast, making it the gold standard in such cases.",
    "rubric": [
        {
            "description": "Essential Criteria: Identifies non-contrast helical CT scan as the most sensitive modality for ureteric stones.",
            "title": "Identify Most Sensitive Modality",
            "weight": 5
        },
        ...
    ]
}
```
In the RaR-Implicit method, the LLM judge evaluates a group of responses against the provided rubrics and outputs a score in the range [0, 1] for each response. The higher the score, the better the response satisfies the rubrics.
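For intuition, here is a minimal, illustrative sketch of this scoring scheme. It is not the exact code shipped in this example (the real logic lives in `RubricJudgeWorkflow.get_judge_reward`), and `ask_judge` is a hypothetical stand-in for the auxiliary judge model call:

```python
# Illustrative sketch of rubric-based scoring for a group of responses.
from typing import Callable, Dict, List


def score_responses(
    question: str,
    responses: List[str],
    rubrics: List[Dict],
    ask_judge: Callable[[str], int],  # hypothetical judge call returning an integer rating in 1..10
) -> List[float]:
    # Render the weighted rubrics the way the judge prompt lists them.
    rubric_text = "\n".join(
        f"Rubric {i} (weight: {r['weight']}): {r['description']}" for i, r in enumerate(rubrics)
    )
    rewards = []
    for response in responses:
        judge_prompt = (
            f"<prompt>\n{question}\n</prompt>\n"
            f"<response>\n{response}\n</response>\n"
            f"<rubrics>\n{rubric_text}\n</rubrics>"
        )
        rating = ask_judge(judge_prompt)  # 1 (poor) .. 10 (excellent)
        rewards.append(rating * 0.1)      # normalize to the [0, 1] reward range
    return rewards
```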

The config file is located in [`rubric.yaml`](./rubric.yaml). To run this example, use the following command:
```bash
trinity run --config examples/grpo_rubric_as_reward/rubric.yaml
```

With the provided configuration, the reward increases over the training steps:

![reward](../../docs/sphinx_doc/assets/grpo_rubric_reward.png)
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
project: "Trinity-RFT-Example"
name: "MedicineQA_grpo_rubric_as_reward"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
  algorithm_type: grpo
  advantage_fn_args:
    std_threshold: 0.0001  # effectively zero
  repeat_times: 8
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}
  max_response_tokens: 1024
  max_model_len: 10240
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 2
  batch_size: 96
  train_batch_size: 768  # 8*96
  explorer_input:
    taskset:
      name: rar_medicine
      storage_type: file
      path: anisha2102/RaR-Medicine
      split: train
      format:
        prompt_key: 'question'
        response_key: 'reference_answer'
      rollout_args:
        temperature: 1.0
      enable_progress_bar: false
    default_workflow_type: 'rubric_judge_workflow'
  trainer_input:
    experience_buffer:
      name: experience_buffer
      storage_type: queue
      use_priority_queue: true
explorer:
  eval_interval: 10
  max_timeout: 3600
  rollout_model:
    engine_num: 2
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
  auxiliary_models:
    - model_path: Qwen/Qwen3-30B-A3B-Instruct-2507
      engine_num: 1
      tensor_parallel_size: 2
      enable_thinking: false
      max_prompt_tokens: 19456
      max_response_tokens: 1024
      max_model_len: 20480
synchronizer:
  sync_style: dynamic_by_explorer
  sync_method: 'nccl'
  sync_interval: 5
  sync_timeout: 3600
trainer:
  save_interval: 100
  trainer_config:
    actor_rollout_ref:
      model:
        use_remove_padding: true
      actor:
        use_dynamic_bsz: true
        ppo_max_token_len_per_gpu: 16384
        ulysses_sequence_parallel_size: 1
        optim:
          lr: 2e-6
      ref:
        log_prob_use_dynamic_bsz: ${trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz}
        log_prob_max_token_len_per_gpu: ${trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
        ulysses_sequence_parallel_size: ${trainer.trainer_config.actor_rollout_ref.actor.ulysses_sequence_parallel_size}  # sp size

trinity/common/models/vllm_model.py

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ def __init__(
             gpu_memory_utilization=config.gpu_memory_utilization,
             enable_chunked_prefill=config.enable_chunked_prefill,
             # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage
+            disable_log_stats=True,
             enable_lora=config.enable_lora,
             **config.lora_kwargs,
         )

trinity/common/workflows/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,7 @@
 from trinity.common.workflows.math_trainable_ruler_workflow import (
     MathTrainableRULERWorkflow,
 )
+from trinity.common.workflows.rubric_judge_workflow import RubricJudgeWorkflow
 from trinity.common.workflows.simple_mm_workflow import (
     AsyncSimpleMMWorkflow,
     SimpleMMWorkflow,
@@ -90,4 +91,5 @@
     "MathTrainableRULERWorkflow",
     "AsyncSimpleMMWorkflow",
     "SimpleMMWorkflow",
+    "RubricJudgeWorkflow",
 ]
Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
# -*- coding: utf-8 -*-
"""A workflow with LLM-as-a-judge."""
import json
from typing import List, Optional, Tuple

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task


@WORKFLOWS.register_module("rubric_judge_workflow")
class RubricJudgeWorkflow(SimpleWorkflow):
    """A workflow using LLM-as-a-judge and rubrics to get the reward.

    Adapted from https://arxiv.org/pdf/2507.17746
    """

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )

    def reset(self, task: Task):
        """Modified from SimpleWorkflow.reset"""
        self.format_args = task.format_args
        self.system_prompt = task.format_args.system_prompt
        self.reply_prefix = task.format_args.reply_prefix

        if self.system_prompt is None:
            self.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
"""

        self.raw_task = task.raw_task
        self.task_desc = task.task_desc
        self.truth = task.truth
        self.rubric = self.raw_task.get("rubric", [])

    def run(self) -> List[Experience]:
        """Modified from SimpleWorkflow.run"""

        messages = self.format_messages()

        self.logger.debug("start chat")
        responses = self.model.chat(messages, **self.rollout_args)

        # === Calculate rubric-based rewards ===
        assert (
            self.auxiliary_models is not None
        ), "Current implementation of rubric-based rewards requires that auxiliary_models is not None."

        judge_success_list = []
        for i, response in enumerate(responses):
            judge_success, reward = self.get_judge_reward(
                response=response.response_text, judger=self.auxiliary_models[0]
            )
            response.reward = reward
            response.eid.run = i + self.run_id_base

            judge_success_list.append(judge_success)

            if i == 0:
                self.logger.debug(
                    f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {response.reward}"
                )

        # record judge success
        judge_success_rate = (
            sum(judge_success_list) / len(judge_success_list) if judge_success_list else 0.0
        )
        for response in responses:
            if response.metrics is None:
                response.metrics = {}
            response.metrics.update({"judge_success": float(judge_success_rate)})

        return responses

    def get_judge_reward(self, response: str, judger: openai.OpenAI) -> Tuple[bool, float]:
        """Get rewards with LLM-as-a-judge
        The prompts are adapted from RAR-IMPLICIT method in https://arxiv.org/pdf/2507.17746
        """
        # Step 1: format prompts
        # system prompt
        ruler_system_prompt = """You are an expert evaluator. Given a user prompt, a generated response, and a list of quality rubrics, please rate the overall quality of the response on a scale of 1 to 10 based on how well it satisfies the rubrics.
Consider all rubrics holistically when determining your score. A response that violates multiple rubrics should receive a lower score, while a response that satisfies all rubrics should receive a higher score.
Start your response with a valid JSON object that starts with "```json" and ends with "```". The JSON object should contain
a single key "rating" and the value should be an integer between 1 and 10.
Example response:
```json
{
    "rating": 7
}```"""
        # user prompt
        if len(self.rubric) > 0:
            rubric_prompt_parts = [
                f"Rubric {i} (weight: {single_rubric['weight']}): {single_rubric['description']}"
                for i, single_rubric in enumerate(self.rubric)
            ]
            rubric_list_string = "\n".join(rubric_prompt_parts)
        else:
            self.logger.warning("No rubric is provided!")
            rubric_list_string = "Rubrics are not provided."

        ruler_user_prompt = f"""Given the following prompt, response, and rubrics, please rate the overall quality of the response on a scale of 1 to 10 based
on how well it satisfies the rubrics.
<prompt>
{self.task_desc}
</prompt>
<response>
{response}
</response>
<rubrics>
{rubric_list_string}
</rubrics>
Your JSON Evaluation:
""".strip()

        # Step 2: invoke judger LLM
        messages = [
            {"role": "system", "content": ruler_system_prompt},
            {"role": "user", "content": ruler_user_prompt},
        ]
        completion = judger.chat.completions.create(
            model=judger.model_path, messages=messages, stream=False, temperature=0.0
        )
        judger_response = completion.choices[0].message.content
        self.logger.debug(f"LLM judge response: {judger_response}")

        # Step 3: extract score from judger's response (expecting a JSON block with "rating")
        try:
            # Extract content between ```json and ```
            start_tag = "```json"
            start_index = judger_response.find(start_tag)
            if start_index == -1:
                start_tag = "```"
                start_index = judger_response.find(start_tag)

            if start_index == -1:
                self.logger.warning("No JSON code block found in judger response.")
                return False, 0.0

            end_index = judger_response.find("```", start_index + len(start_tag))
            if end_index == -1:
                self.logger.warning("Malformed JSON code block in judger response.")
                return False, 0.0

            json_str = judger_response[start_index + len(start_tag) : end_index].strip()
            parsed = json.loads(json_str)
            rating = parsed.get("rating")

            if not isinstance(rating, (int, float)) or not (1 <= rating <= 10):
                self.logger.warning(f"Invalid or out-of-range rating: {rating}")
                return False, 0.0

            normalized_score = rating * 0.1  # Normalize 1-10 to 0-1 scale
            return True, normalized_score

        except json.JSONDecodeError as e:
            self.logger.warning(f"Failed to parse JSON from judger response: {e}")
            return False, 0.0
        except Exception as e:
            self.logger.warning(f"Unexpected error when processing judger response: {e}")
            return False, 0.0
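
As an aside (not part of the commit), the rating-extraction step above can be exercised standalone; the helper below is a simplified, illustrative mirror of the Step 3 parsing logic in `get_judge_reward`:

# Illustrative, simplified mirror of the Step 3 parsing logic in get_judge_reward.
import json
from typing import Tuple


def extract_reward(judger_response: str) -> Tuple[bool, float]:
    # Locate the fenced JSON block emitted by the judge, preferring a ```json fence.
    start_tag = "```json"
    start = judger_response.find(start_tag)
    if start == -1:
        start_tag = "```"
        start = judger_response.find(start_tag)
    if start == -1:
        return False, 0.0
    end = judger_response.find("```", start + len(start_tag))
    if end == -1:
        return False, 0.0
    try:
        rating = json.loads(judger_response[start + len(start_tag) : end].strip()).get("rating")
    except json.JSONDecodeError:
        return False, 0.0
    if not isinstance(rating, (int, float)) or not (1 <= rating <= 10):
        return False, 0.0
    return True, rating * 0.1  # normalize the 1-10 rating to [0, 1]


print(extract_reward('```json\n{"rating": 5}\n```'))  # -> (True, 0.5)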

trinity/trainer/verl_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@
 
 class CheckpointMonitor:
     def __init__(self, default_local_dir: str, default_hdfs_dir: str = None):
-        self.logger = get_logger("Checkpoint Monitor", in_ray_actor=True)
+        self.logger = get_logger("checkpoint_monitor", in_ray_actor=True)
         self.default_local_dir = default_local_dir
         self.default_hdfs_dir = default_hdfs_dir
         self.local_latest_checkpointed_iteration = os.path.join(
