
Commit 03b41d6

upgrade reward functions

1 parent 021914c · commit 03b41d6

3 files changed: +123 -27 lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 3 additions & 1 deletion
@@ -127,7 +127,9 @@ def __init__(
             "answer_end": {"text": "</answer>", "num_occur": 1},
         }
         reward_model_kwargs = {
-            k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
         }
         self.reward_model = VerifiableReward(
             reward_fns=[
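For reference, a minimal sketch (not part of the commit) of how the updated comprehension filters grpo_config; the sample values below are hypothetical:

# hypothetical grpo_config; only the three reward-related keys are forwarded
grpo_config = {
    "beta": 0.01,
    "soft_over_length_punishment": True,
    "max_new_tokens": 4096,
    "cache_length": 1024,
    "filter_truncated_response": True,
}
reward_model_kwargs = {
    k: v
    for k, v in grpo_config.items()
    if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
}
print(reward_model_kwargs)
# {'soft_over_length_punishment': True, 'max_new_tokens': 4096, 'cache_length': 1024}

The switch from "max_length" to "max_new_tokens" means the reward functions now receive the response-token budget rather than the combined prompt-plus-response budget.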

applications/ColossalChat/coati/distributed/reward/reward_fn.py

Lines changed: 117 additions & 26 deletions
@@ -1,7 +1,77 @@
 import torch
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
 
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+    """
+    Verify if the completion is a valid math representation of the gt_answer.
+    """
+    if not completion.startswith("\\boxed{"):
+        completion = "\\boxed{" + completion + "}"
+    if not gt_answer.startswith("\\boxed{"):
+        gt_answer = "\\boxed{" + gt_answer + "}"
+    target = (
+        ExprExtractionConfig(),
+        LatexExtractionConfig(
+            normalization_config=NormalizationConfig(
+                nits=False,
+                malformed_operators=False,
+                basic_latex=True,
+                boxed="all",
+                units=True,
+            ),
+            boxed_match_priority=0,
+        ),
+    )
+    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+        raise ValueError("gt_answer should be a string, please verify your training data.")
+    if not isinstance(completion, str) or len(completion) == 0:
+        return MATCHING_FAIL
+    try:
+        parsed_gt_answer = parse(gt_answer, extraction_config=target)
+        if len(parsed_gt_answer) == 0:
+            return CANNOT_PARSE_GT_ANSWER
+        parsed_completion = parse(completion, extraction_config=target)
+        if len(parsed_completion) == 0:
+            return CANNOT_PARSE_PREDICTION
+        if verify(parsed_gt_answer, parsed_completion):
+            return SUCCESS
+        else:
+            return MATCHING_FAIL
+    except Exception:
+        return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
+    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+    exact_match_result = (
+        SUCCESS
+        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+        == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+        else MATCHING_FAIL
+    )
+    if math_verify_result == SUCCESS:
+        ans_acc += 1
+        reward += acc_score
+    elif exact_match_result == SUCCESS:
+        # sometimes for answers that's not a (valid) math expression, math_verify will fail
+        ans_acc += 1
+        if math_verify_result == CANNOT_PARSE_PREDICTION:
+            reward += (
+                acc_score / 2
+            )  # not a valid latex math representation, but the answer is correct, receive half of the score
+        else:
+            reward += acc_score
+    return reward, ans_acc
+
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
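A hypothetical usage sketch of the two helpers added above (assuming math_verify and latex2sympy2_extended are installed; the import path is inferred from the file location):

from coati.distributed.reward.reward_fn import verify_math_representation, verify_model_answer

# Equivalent representations are expected to verify even when the strings differ,
# e.g. a completion of "\\frac{1}{2}" against a ground truth of "1/2".
result = verify_math_representation("\\frac{1}{2}", "1/2")

# verify_model_answer layers a plain string comparison on top of math_verify:
#   math verification succeeds        -> ans_acc + 1, reward + acc_score
#   only the exact string match holds -> ans_acc + 1, reward + acc_score
#                                        (acc_score / 2 if the prediction could not be parsed as math)
#   neither holds                     -> reward and ans_acc unchanged
reward, ans_acc = verify_model_answer("1/2", "1/2", ans_acc=0, acc_score=10.0, reward=0.0)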
@@ -14,15 +84,18 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]
 
     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    res_length = e.item() - s.item() + 1
+    if not eval_mode:
+        max_new_tokens = kwargs["max_new_tokens"]
+    else:
+        max_new_tokens = -1  # for eval mode, we don't need to check the length
+    if not eval_mode and soft_over_length_punishment:
+        cache_length = kwargs["cache_length"]
+        if max_new_tokens - cache_length < res_length < max_new_tokens:
+            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
 
     if gt_answer is None:
-        return reward
+        raise ValueError("no gt_answer is provided, please check your training dataset.")
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
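The soft over-length penalty now keys off max_new_tokens (the response budget) instead of max_length, and is skipped entirely in eval mode. A worked example with hypothetical values:

# hypothetical values: response budget 4096, penalty window 1024 tokens wide
max_new_tokens = 4096   # kwargs["max_new_tokens"]
cache_length = 1024     # kwargs["cache_length"]
acc_score = 10.0
res_length = 3800       # tokens in the sampled response

length_reward = 0.0
if max_new_tokens - cache_length < res_length < max_new_tokens:  # 3072 < 3800 < 4096
    # (3072 - 3800) / 1024 * 10.0 == -7.109375
    length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score

A response that lands inside the last cache_length tokens of the budget loses a linearly growing slice of acc_score, approaching the full acc_score penalty right at the budget.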
@@ -35,15 +108,15 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if (
-        format_valid
-        and final_answer is not None
-        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
-    ):
-        ans_acc += 1
-        reward += acc_score
+    if final_answer is not None:
+        if eval_mode or format_valid:
+            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+        if not eval_mode:
+            reward = reward + length_reward
 
-    reward = reward + length_reward
+    # Check if the sequence is over length
+    if not eval_mode and res_length >= max_new_tokens:
+        reward *= 0.0
 
     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
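Two behavioural changes in the hunk above: the answer reward is now computed by verify_model_answer whenever a final answer was parsed (in eval mode even without a valid format), and over-length responses are zeroed out in training mode. A small sketch of the training-mode composition, with hypothetical numbers:

acc_score, length_reward = 10.0, -7.109375
reward = 0.0

answer_correct = True            # i.e. verify_model_answer granted full credit
if answer_correct:
    reward += acc_score          # 10.0
reward = reward + length_reward  # 2.890625 after the soft length penalty

res_length, max_new_tokens = 3800, 4096
if res_length >= max_new_tokens:
    reward *= 0.0                # only triggers for truncated responses; not in this example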
@@ -56,6 +129,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             "parsed": final_answer,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),
+            "response_length": res_length,
+            "reward": reward.item(),
         }
 
 
@@ -71,31 +146,45 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]
 
     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    res_length = e.item() - s.item() + 1
+    if not eval_mode:
+        max_new_tokens = kwargs["max_new_tokens"]
+    else:
+        max_new_tokens = -1  # for eval mode, we don't need to check the length
+    if not eval_mode and soft_over_length_punishment:
+        cache_length = kwargs["cache_length"]
+        if max_new_tokens - cache_length < res_length < max_new_tokens:
+            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
 
     if gt_answer is None:
-        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+        raise ValueError("no gt_answer is provided, please check your training dataset.")
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
+    if "tags" in kwargs and kwargs["tags"]:
+        tags = kwargs["tags"]
+        format_valid = format_valid and all(
+            [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
+        )
     # Check format accuracy
     if format_valid:
         format_acc += 1
         reward += format_score
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
-        ans_acc += 1
-        reward += acc_score
+    if final_answer is not None:
+        if eval_mode or format_valid:
+            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+        if not eval_mode:
+            reward = reward + length_reward
+
+    # Check if the sequence is over length
+    if not eval_mode and res_length >= max_new_tokens:
+        reward *= 0.0
 
-    reward = reward + length_reward
     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     else:
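The new tags check makes the boxed reward honor the same response_format_tags dict that grpo_consumer.py builds (the "answer_end" entry is visible in the first file of this diff; the remaining entries below are assumed for illustration):

decoded_final_answer = "<think> ... </think><answer>\\boxed{42}</answer>"
tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}
format_valid = all(
    decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags
)
# True here: every tag appears exactly the required number of times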
@@ -107,4 +196,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             "parsed": final_answer,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),
+            "response_length": res_length,
+            "reward": reward.item(),
         }

applications/ColossalChat/rl_example.py

Lines changed: 3 additions & 0 deletions
@@ -198,6 +198,8 @@
             "beta": args.kl_coeff,  # KL penalty coefficient
             "loss_variation": "sample_level",
             "reward_fn_type": args.reward_type,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -213,6 +215,7 @@
             "loss_variation": "token_level",
             "soft_over_length_punishment": True,
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
             "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
             "reward_fn_type": args.reward_type,
