Commit 14f237c

[feat] Support boxed math reward (#6284)
* fix pp+tp, fix dataloader
* fixed plugin micro-batch size
* support boxed reward
* add boxed reward
* fix pp state dict incomplete issue
* Revert "fix pp state dict incomplete issue"
  This reverts commit 6c1b3b6.
1 parent 2ca1e3c commit 14f237c

File tree

5 files changed: +118 -12 lines


applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def setup(self) -> None:
             and "num_microbatches" not in self.plugin_config
             and "microbatch_size" not in self.plugin_config
         ):
-            plugin_config["microbatch_size"] = self.minibatch_size
+            plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1))
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
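
Note on this change: with pipeline parallelism the mini-batch has to be split across pipeline stages, so the plugin's microbatch_size is now derived from minibatch_size // pp_size (floored at 1) instead of using the raw mini-batch size. A minimal sketch of the sizing rule, with hypothetical values chosen only for illustration:

    # Illustration only: hypothetical sizes, not values taken from this repository.
    minibatch_size = 8
    pp_size = 2  # pipeline-parallel degree; treated as 1 when not configured
    microbatch_size = max(1, minibatch_size // pp_size)
    print(microbatch_size)  # -> 4; with pp_size=1 it stays equal to the mini-batch size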

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 10 additions & 3 deletions
@@ -7,7 +7,7 @@
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import math_reward_fn
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
@@ -54,7 +54,9 @@ def __init__(
             and "num_microbatches" not in plugin_config
             and "microbatch_size" not in plugin_config
         ):
-            plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
+            plugin_config["microbatch_size"] = max(
+                1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
+            )
         super().__init__(
             num_producers,
             num_episodes,
@@ -131,7 +133,12 @@ def __init__(
             k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
         }
         self.reward_model = VerifiableReward(
-            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
+            reward_fns=[
+                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
+            ],
+            tokenizer=self.tokenizer,
+            tags=response_format_tags,
+            **reward_model_kwargs,
         )
         self.global_step = 0
         self.use_wandb = use_wandb
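
Note on this change: the consumer now chooses the reward function from grpo_config["reward_fn_type"]. "think_answer_tags" keeps the original <think>/<answer> parsing reward, while any other value (the example script passes "boxed") selects the new \boxed{}-based reward. A minimal sketch of the selection logic, using stand-in strings instead of the real functions purely for illustration:

    # Illustration only: stand-in strings in place of the actual reward functions.
    grpo_config = {"reward_fn_type": "boxed"}  # hypothetical config
    chosen = (
        "math_reward_fn"
        if grpo_config.get("reward_fn_type") == "think_answer_tags"
        else "boxed_math_reward_fn"
    )
    print(chosen)  # -> boxed_math_reward_fn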

applications/ColossalChat/coati/distributed/reward/reward_fn.py

Lines changed: 41 additions & 1 deletion
@@ -1,6 +1,6 @@
 import torch
 
-from .reward_utils import extract_solution, validate_response_structure
+from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
@@ -70,3 +70,43 @@ def gsm8k_reward_fn(input_ids, **kwargs):
     if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
         reward = reward + 9.0
     return reward
+
+
+def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
+    tokenizer = kwargs["tokenizer"]
+    soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
+    format_score = 0.0
+    acc_score = 10.0
+    reward = torch.tensor(0.0)
+    format_acc = torch.tensor(0.0)
+    ans_acc = torch.tensor(0.0)
+    s, e = response_idx[0], response_idx[1]
+
+    length_reward = 0.0
+    if soft_over_length_punishment:
+        max_length = kwargs.get("max_length", 1024 * 4)
+        cache_length = kwargs.get("cache_length", 512)
+        res_length = e.item() - s.item() + 1
+        if max_length - cache_length < res_length < max_length:
+            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+
+    if gt_answer is None:
+        return reward
+
+    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
+    final_answer = extract_boxed_solution(decoded_final_answer)
+    format_valid = final_answer is not None
+    # Check format accuracy
+    if format_valid:
+        format_acc += 1
+        reward += format_score
+
+    # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
+    if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
+        ans_acc += 1
+        reward += acc_score
+
+    reward = reward + length_reward
+
+    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
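
Note on the soft over-length penalty: when the decoded response length falls inside the window (max_length - cache_length, max_length), the reward is reduced linearly, approaching -acc_score as the length approaches max_length. A small worked example with hypothetical lengths, using the function's default values max_length=4096, cache_length=512, acc_score=10.0:

    # Illustration only: hypothetical response lengths, mirroring the penalty formula above.
    max_length, cache_length, acc_score = 4096, 512, 10.0
    for res_length in (3500, 3712, 3968, 4095):
        length_reward = 0.0
        if max_length - cache_length < res_length < max_length:
            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
        print(res_length, round(length_reward, 2))
    # 3500 0.0    (below the penalty window, no penalty)
    # 3712 -2.5   (a quarter of the way through the window)
    # 3968 -7.5
    # 4095 -9.98  (nearly the full -acc_score)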

applications/ColossalChat/coati/distributed/reward/reward_utils.py

Lines changed: 48 additions & 0 deletions
@@ -74,3 +74,51 @@ def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
 
     final_answer = matches[-1].group(1).strip()
     return final_answer, solution_str
+
+
+def extract_boxed_solution(text: str) -> Optional[str]:
+    """
+    Modified from: https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3
+    Retrieves the content from the last occurrence of `\boxed{}` in a LaTeX-like string.
+
+    Args:
+        text (str): A string potentially containing LaTeX-style boxed expressions.
+
+    Returns:
+        Optional[str]: The text inside the final `\boxed{}` if successfully extracted;
+        returns `None` if no properly closed box is found.
+
+    Examples:
+        >>> extract_boxed_solution("The answer is \\boxed{42}.")
+        '42'
+        >>> extract_boxed_solution("Here is an unmatched \\boxed{42")
+        None
+    """
+    try:
+        # Find the last occurrence of "\boxed{"
+        start_idx = text.rindex("\\boxed{")
+        # Move past "\boxed{" to find the start of the content
+        content_start = start_idx + len("\\boxed{")
+        open_braces = 1
+        pos = content_start
+
+        # Traverse the string to find the matching closing brace
+        while open_braces > 0 and pos < len(text):
+            if text[pos] == "{":
+                open_braces += 1
+            elif text[pos] == "}":
+                open_braces -= 1
+            pos += 1
+
+        # If all braces are matched, extract and return the content
+        if open_braces == 0:
+            return text[content_start : pos - 1].strip()
+        else:
+            return None
+
+    except ValueError:
+        # "\boxed{" not found
+        return None
+    except Exception:
+        # Any other unexpected error
+        return None
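
Usage note: the brace-counting loop extracts the last \boxed{...} and copes with nested braces, which a greedy regular expression would typically get wrong. A short sketch, assuming the function is importable from the module added above:

    # Illustration only.
    from coati.distributed.reward.reward_utils import extract_boxed_solution

    print(extract_boxed_solution("So x = \\boxed{\\frac{1}{2}}"))        # -> \frac{1}{2} (nested braces kept intact)
    print(extract_boxed_solution("First \\boxed{1}, then \\boxed{2}."))  # -> 2 (last occurrence wins)
    print(extract_boxed_solution("Unclosed \\boxed{42"))                 # -> None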

applications/ColossalChat/rl_example.py

Lines changed: 18 additions & 7 deletions
@@ -86,6 +86,14 @@
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
+    parser.add_argument(
+        "-rt",
+        "--reward-type",
+        type=str,
+        default="think_answer_tags",
+        choices=["think_answer_tags", "boxed"],
+        help="Reward type for GRPO.",
+    )
 
     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -136,8 +144,8 @@
                 max_length=args.max_new_tokens + args.max_prompt_tokens,
                 do_sample=True,
                 max_new_tokens=None,
-                early_stopping=False,
-                stop_strings=["</answer>"],
+                early_stopping=False if args.reward_type == "think_answer_tags" else True,
+                stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
     elif args.backend == "vllm":
@@ -153,9 +161,9 @@
         generate_config.update(
             dict(
                 max_tokens=args.max_new_tokens,  # max new tokens
-                ignore_eos=True,
+                ignore_eos=True if args.reward_type == "think_answer_tags" else False,
                 include_stop_str_in_output=True,
-                stop=["</answer>"],
+                stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
     else:
@@ -168,6 +176,7 @@
             "train_microbatch_size": args.train_microbatch_size,
             "beta": args.kl_coeff,  # KL penalty coefficient
             "loss_variation": "sample_level",
+            "reward_fn_type": args.reward_type,
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -185,6 +194,7 @@
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
             "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
+            "reward_fn_type": args.reward_type,
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
@@ -212,14 +222,15 @@
         plugin_config={
             "zero_stage": 2,
         },  # for zero
-        # currently not support tp/pp
         # plugin_config={
         #     "tp_size": 2,
        #     "pp_size": 2,
-        #     "microbatch_size": max(1, args.train_microbatch_size // 2),
+        #     "microbatch_size": max(
+        #         1, args.train_microbatch_size // 2
+        #     ),  # microbatch size should be set to train_microbatch_size // pp_size
         #     "zero_stage": 0,
         #     "max_norm": 1.0,
-        # },  # for pp
+        # },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,
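
Usage note: the new flag is opt-in and defaults to the original behavior. A hypothetical invocation (the script's other required arguments, such as model and dataset paths, are omitted here):

    python rl_example.py -a GRPO -rt boxed ...

With -rt boxed, the "</answer>" stop string is dropped and EOS is respected (ignore_eos=False for the vllm backend); with the default -rt think_answer_tags the previous generation settings are unchanged.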
