Commit bcf2459

Merge pull request #6314 from hpcaitech/grpo-reward-dev
[feat] upgrade reward functions
2 parents: 3c42c0c + f8bd2db (commit bcf2459)

6 files changed: +183 additions, -51 deletions
.gitignore

Lines changed: 1 addition & 0 deletions

@@ -166,3 +166,4 @@ applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
 applications/ColossalChat/eval
+applications/ColossalChat/rollouts

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 13 additions & 7 deletions

@@ -120,14 +120,20 @@ def __init__(
             "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
         )
         # Initialize verifiable reward.
-        response_format_tags = {
-            "think_start": {"text": "<think>", "num_occur": 1},
-            "think_end": {"text": "</think>", "num_occur": 1},
-            "answer_start": {"text": "<answer>", "num_occur": 1},
-            "answer_end": {"text": "</answer>", "num_occur": 1},
-        }
+        response_format_tags = (
+            {
+                "think_start": {"text": "<think>", "num_occur": 1},
+                "think_end": {"text": "</think>", "num_occur": 1},
+                "answer_start": {"text": "<answer>", "num_occur": 1},
+                "answer_end": {"text": "</answer>", "num_occur": 1},
+            }
+            if grpo_config.get("reward_fn_type") == "think_answer_tags"
+            else None
+        )
         reward_model_kwargs = {
-            k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
         }
         self.reward_model = VerifiableReward(
             reward_fns=[

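For context on how the consumer's conditional tags are used later in this PR: the boxed reward function counts each tag's occurrences in the decoded response and requires an exact match. A minimal sketch of that check follows; the sample completion string is made up for illustration and is not part of the PR.

# Hedged sketch: how response_format_tags is consumed downstream.
# The sample completion below is illustrative.
response_format_tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}

completion = "<think>2 + 2 = 4</think><answer>\\boxed{4}</answer>"

# Mirrors the tag check added to boxed_math_reward_fn: every tag must occur
# exactly num_occur times in the decoded response.
tags_ok = all(
    completion.count(spec["text"]) == spec["num_occur"]
    for spec in response_format_tags.values()
)
print(tags_ok)  # True: each tag appears exactly once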
applications/ColossalChat/coati/distributed/launch.py

Lines changed: 8 additions & 0 deletions

@@ -55,6 +55,8 @@ def launch_distributed(
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
+    log_rollout_interval: int = 20,
+    rollout_save_dir: str = "./rollout",
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -72,6 +74,10 @@ def launch_distributed(

     run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
     wandb_group_name = str(uuid.uuid4())
+    rollout_log_file = os.path.join(
+        rollout_save_dir,
+        f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
+    )

     procs = []
     for i in range(num_producers):
@@ -98,6 +104,8 @@ def launch_distributed(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)

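The new rollout_log_file path combines the save directory, the project name with spaces replaced by underscores, and the wandb group UUID. A small sketch of the resulting filename; the project name and directory below are illustrative, not taken from the PR.

import os
import uuid

# Illustrative values; the real ones come from the CLI arguments.
project_name = "GRPO Example"
rollout_save_dir = "./rollouts"
wandb_group_name = str(uuid.uuid4())

rollout_log_file = os.path.join(
    rollout_save_dir,
    f"{project_name.replace(' ', '_')}_run_{wandb_group_name}.jsonl",
)
print(rollout_log_file)
# e.g. ./rollouts/GRPO_Example_run_8c6f0f3e-....jsonl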
applications/ColossalChat/coati/distributed/producer.py

Lines changed: 36 additions & 18 deletions

@@ -1,4 +1,5 @@
 import copy
+import json
 import os
 from typing import Any, Dict, Optional

@@ -49,7 +50,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
-        wandb_log_rollout_interval: int = 20,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -70,9 +72,16 @@ def __init__(
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
         self.eval_mode = False
-        self.wandb_rollout_data = []
-        self.wandb_log_rollout_interval = wandb_log_rollout_interval
+        self.log_rollout_interval = log_rollout_interval
         self.latest_rollout_log_step = -1
+        if producer_idx == 0:
+            if os.path.exists(rollout_log_file):
+                raise ValueError(
+                    f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
+                )
+            else:
+                os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
+            self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
@@ -320,6 +329,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         super().__init__(
             producer_idx,
@@ -342,6 +353,8 @@ def __init__(
             project_name=project_name,
            run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@@ -353,26 +366,31 @@ def __init__(
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
         if self.producer_idx == 0 and not self.eval_mode:
-            wandb_rollout_data = self.wandb_rollout_data + [
-                [
-                    str(self.consumer_global_step),
-                    str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
-                ]
-            ]
             if (
-                self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
+                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
                 or self.latest_rollout_log_step == -1
             ):
-                self.wandb_rollout_data = wandb_rollout_data
-                self.latest_rollout_log_step = self.consumer_global_step
-                self.wandb_run.log(
-                    {
-                        "rollout/rollout_examples": wandb.Table(
-                            columns=["train_step", "rollout_examples"], data=wandb_rollout_data
+                new_record = (
+                    json.dumps(
+                        {
+                            "train_step": self.consumer_global_step,
+                            "rollout": self.tokenizer.batch_decode(
+                                rollouts["input_ids"][:, 0], skip_special_tokens=True
+                            ),
+                        }
                     )
-                    }
-                )
+                    + "\n"
+                )
+                self.rollout_log_file.write(new_record)
+                self.rollout_log_file.flush()
+                self.latest_rollout_log_step = self.consumer_global_step
         return rollouts

+    def __del__(self):
+        if self.producer_idx == 0:
+            self.wandb_run.finish()
+        if hasattr(self, "rollout_log_file"):
+            self.rollout_log_file.close()
+
     def load_state_dict(self, state_dict):
         self.model.load_state_dict(state_dict)

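Since rollouts are now appended to a JSONL file instead of a wandb table, they can be inspected offline. A hedged sketch of reading the log back; the filename is illustrative, and the record fields follow the json.dumps call in the diff above.

import json

# Each line is one JSON record: {"train_step": int, "rollout": [str, ...]},
# where "rollout" holds the first decoded sample for every prompt in the batch.
with open("./rollouts/GRPO_Example_run_<uuid>.jsonl", encoding="utf8") as f:
    for line in f:
        record = json.loads(line)
        print(record["train_step"], record["rollout"][0][:80])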
applications/ColossalChat/coati/distributed/reward/reward_fn.py

Lines changed: 117 additions & 26 deletions

@@ -1,7 +1,77 @@
 import torch
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify

 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure

+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+    """
+    Verify if the completion is a valid math representation of the gt_answer.
+    """
+    if not completion.startswith("\\boxed{"):
+        completion = "\\boxed{" + completion + "}"
+    if not gt_answer.startswith("\\boxed{"):
+        gt_answer = "\\boxed{" + gt_answer + "}"
+    target = (
+        ExprExtractionConfig(),
+        LatexExtractionConfig(
+            normalization_config=NormalizationConfig(
+                nits=False,
+                malformed_operators=False,
+                basic_latex=True,
+                boxed="all",
+                units=True,
+            ),
+            boxed_match_priority=0,
+        ),
+    )
+    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+        raise ValueError("gt_answer should be a string, please verify your training data.")
+    if not isinstance(completion, str) or len(completion) == 0:
+        return MATCHING_FAIL
+    try:
+        parsed_gt_answer = parse(gt_answer, extraction_config=target)
+        if len(parsed_gt_answer) == 0:
+            return CANNOT_PARSE_GT_ANSWER
+        parsed_completion = parse(completion, extraction_config=target)
+        if len(parsed_completion) == 0:
+            return CANNOT_PARSE_PREDICTION
+        if verify(parsed_gt_answer, parsed_completion):
+            return SUCCESS
+        else:
+            return MATCHING_FAIL
+    except Exception:
+        return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
+    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+    exact_match_result = (
+        SUCCESS
+        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+        == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+        else MATCHING_FAIL
+    )
+    if math_verify_result == SUCCESS:
+        ans_acc += 1
+        reward += acc_score
+    elif exact_match_result == SUCCESS:
+        # sometimes for answers that's not a (valid) math expression, math_verify will fail
+        ans_acc += 1
+        if math_verify_result == CANNOT_PARSE_PREDICTION:
+            reward += (
+                acc_score / 2
+            )  # not a valid latex math representation, but the answer is correct, receive half of the score
+        else:
+            reward += acc_score
+    return reward, ans_acc
+

 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
@@ -14,15 +84,18 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]

     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    res_length = e.item() - s.item() + 1
+    if not eval_mode:
+        max_new_tokens = kwargs["max_new_tokens"]
+    else:
+        max_new_tokens = -1  # for eval mode, we don't need to check the length
+    if not eval_mode and soft_over_length_punishment:
+        cache_length = kwargs["cache_length"]
+        if max_new_tokens - cache_length < res_length < max_new_tokens:
+            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score

     if gt_answer is None:
-        return reward
+        raise ValueError("no gt_answer is provided, please check your training dataset.")

     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
@@ -35,15 +108,15 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if (
-        format_valid
-        and final_answer is not None
-        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
-    ):
-        ans_acc += 1
-        reward += acc_score
+    if final_answer is not None:
+        if eval_mode or format_valid:
+            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+        if not eval_mode:
+            reward = reward + length_reward

-    reward = reward + length_reward
+    # Check if the sequence is over length
+    if not eval_mode and res_length >= max_new_tokens:
+        reward *= 0.0

     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
@@ -56,6 +129,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             "parsed": final_answer,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),
+            "response_length": res_length,
+            "reward": reward.item(),
         }


@@ -71,31 +146,45 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]

     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    res_length = e.item() - s.item() + 1
+    if not eval_mode:
+        max_new_tokens = kwargs["max_new_tokens"]
+    else:
+        max_new_tokens = -1  # for eval mode, we don't need to check the length
+    if not eval_mode and soft_over_length_punishment:
+        cache_length = kwargs["cache_length"]
+        if max_new_tokens - cache_length < res_length < max_new_tokens:
+            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score

     if gt_answer is None:
-        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+        raise ValueError("no gt_answer is provided, please check your training dataset.")

     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
+    if "tags" in kwargs and kwargs["tags"]:
+        tags = kwargs["tags"]
+        format_valid = format_valid and all(
+            [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
+        )
     # Check format accuracy
     if format_valid:
         format_acc += 1
         reward += format_score

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
-        ans_acc += 1
-        reward += acc_score
+    if final_answer is not None:
+        if eval_mode or format_valid:
+            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+        if not eval_mode:
+            reward = reward + length_reward
+
+    # Check if the sequence is over length
+    if not eval_mode and res_length >= max_new_tokens:
+        reward *= 0.0

-    reward = reward + length_reward
     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     else:
@@ -107,4 +196,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             "parsed": final_answer,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),
+            "response_length": res_length,
+            "reward": reward.item(),
         }

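To make the new scoring behavior concrete, here is a hedged, self-contained sketch of the branches above with illustrative numbers. It does not call math_verify and folds the format/eval-mode handling away, but it reproduces the full-score, half-score, soft length-penalty, and truncation cases.

# Hedged sketch of the new scoring rules, with illustrative config values.
SUCCESS, MATCHING_FAIL, CANNOT_PARSE_PREDICTION = 1, 0, -2

acc_score = 10.0
max_new_tokens, cache_length = 4096, 1024  # illustrative, not taken from the PR


def score(math_verify_result, exact_match_result, res_length):
    reward, ans_acc = 0.0, 0
    # A math-verified answer earns the full score; an answer that only matches as a
    # plain string earns half the score when it could not be parsed as a math expression.
    if math_verify_result == SUCCESS:
        ans_acc += 1
        reward += acc_score
    elif exact_match_result == SUCCESS:
        ans_acc += 1
        reward += acc_score / 2 if math_verify_result == CANNOT_PARSE_PREDICTION else acc_score
    # Soft over-length punishment inside the (max_new_tokens - cache_length, max_new_tokens) window.
    if max_new_tokens - cache_length < res_length < max_new_tokens:
        reward += ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
    # Hard zero once the response reaches the generation limit.
    if res_length >= max_new_tokens:
        reward *= 0.0
    return reward, ans_acc


print(score(SUCCESS, MATCHING_FAIL, 1000))            # (10.0, 1): verified, no penalty
print(score(SUCCESS, MATCHING_FAIL, 3584))            # (5.0, 1): 10 + (3072 - 3584) / 1024 * 10
print(score(CANNOT_PARSE_PREDICTION, SUCCESS, 1000))  # (5.0, 1): string match only, half score
print(score(MATCHING_FAIL, MATCHING_FAIL, 4096))      # (0.0, 0): truncated response is zeroed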
applications/ColossalChat/rl_example.py

Lines changed: 8 additions & 0 deletions

@@ -118,6 +118,9 @@
     parser.add_argument(
         "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
     )
+    parser.add_argument(
+        "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
+    )
     args = parser.parse_args()

     if args.train_minibatch_size is None:
@@ -198,6 +201,8 @@
             "beta": args.kl_coeff,  # KL penalty coefficient
             "loss_variation": "sample_level",
             "reward_fn_type": args.reward_type,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -213,6 +218,7 @@
             "loss_variation": "token_level",
             "soft_over_length_punishment": True,
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
             "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
             "reward_fn_type": args.reward_type,
@@ -266,4 +272,6 @@
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
         eval_generation_config=eval_generation_config,
+        log_rollout_interval=20,
+        rollout_save_dir=args.rollout_save_dir,
     )

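For reference, the derived length settings in the DAPO branch work out as follows; the argument values below are illustrative, not defaults from the script.

# Illustrative arithmetic for the config above.
max_new_tokens = 4096    # args.max_new_tokens
max_prompt_tokens = 512  # args.max_prompt_tokens

max_length = max_new_tokens + max_prompt_tokens    # 4608: prompt plus generation budget
cache_length = min(1024, int(max_new_tokens / 4))  # 1024: width of the soft-punishment window

# The reward functions then apply the soft penalty to responses whose length falls in
# (max_new_tokens - cache_length, max_new_tokens) = (3072, 4096), and zero the reward
# once a response reaches max_new_tokens.
print(max_length, cache_length)  # 4608 1024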