Skip to content

Commit 78a06f5

Browse files
committed
fix missing tags parameter
Parent: 88e3b09 · Commit: 78a06f5

File tree

4 files changed

+27
-10
lines changed

4 files changed

+27
-10
lines changed

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,7 @@ def __init__(
120120
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
121121
)
122122
# Initialize verifiable reward.
123-
response_format_tags = (
124-
{
125-
"think_start": {"text": "<think>", "num_occur": 1},
126-
"think_end": {"text": "</think>", "num_occur": 1},
127-
"answer_start": {"text": "<answer>", "num_occur": 1},
128-
"answer_end": {"text": "</answer>", "num_occur": 1},
129-
}
130-
if grpo_config.get("reward_fn_type") == "think_answer_tags"
131-
else None
132-
)
123+
response_format_tags = grpo_config.get("response_format_tags", None)
133124
reward_model_kwargs = {
134125
k: v
135126
for k, v in grpo_config.items()

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def launch_distributed(
100100
eval_dataset_config=eval_dataset_config,
101101
eval_interval=eval_interval,
102102
evaluation_function_type=grpo_config["reward_fn_type"],
103+
response_format_tags=grpo_config["response_format_tags"],
103104
eval_save_dir=eval_save_dir,
104105
eval_generation_config=eval_generation_config,
105106
project_name=project_name,

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
eval_dataset_config=None,
4747
eval_interval=-1, # disable evaluation
4848
evaluation_function_type="think_answer_tags",
49+
response_format_tags=None,
4950
eval_save_dir: str = "./eval",
5051
project_name: str = None,
5152
run_name: str = None,
@@ -148,6 +149,7 @@ def __init__(
148149
self.evaluation_function = boxed_math_reward_fn
149150
else:
150151
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
152+
self.response_format_tags = response_format_tags
151153
else:
152154
raise ValueError("eval_dataset_config is not defined")
153155
self.device = get_current_device()
@@ -217,6 +219,7 @@ def loop(self) -> None:
217219
eval_outputs["response_idx"][m][n],
218220
tokenizer=self.tokenizer,
219221
eval_mode=True,
222+
tags=self.response_format_tags,
220223
)
221224
for m in range(eval_outputs["input_ids"].size(0))
222225
for n in range(eval_outputs["input_ids"].size(1))
@@ -324,6 +327,7 @@ def __init__(
324327
eval_dataset_config=None,
325328
eval_interval=-1, # disable evaluation
326329
evaluation_function_type="think_answer_tags",
330+
response_format_tags=None,
327331
eval_save_dir: str = "./eval",
328332
eval_generation_config={},
329333
project_name: str = None,
@@ -349,6 +353,7 @@ def __init__(
349353
eval_dataset_config=eval_dataset_config,
350354
eval_interval=eval_interval,
351355
evaluation_function_type=evaluation_function_type,
356+
response_format_tags=response_format_tags,
352357
eval_save_dir=eval_save_dir,
353358
project_name=project_name,
354359
run_name=run_name,

applications/ColossalChat/rl_example.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,16 @@
231231
"reward_fn_type": args.reward_type,
232232
"max_length": args.max_new_tokens + args.max_prompt_tokens,
233233
"max_new_tokens": args.max_new_tokens,
234+
"response_format_tags": (
235+
{
236+
"think_start": {"text": "<think>", "num_occur": 1},
237+
"think_end": {"text": "</think>", "num_occur": 1},
238+
"answer_start": {"text": "<answer>", "num_occur": 1},
239+
"answer_end": {"text": "</answer>", "num_occur": 1},
240+
}
241+
if args.reward_type == "think_answer_tags"
242+
else None
243+
),
234244
}
235245
elif args.algo == "DAPO":
236246
# DAPO variant settings
@@ -250,6 +260,16 @@
250260
"cache_length": min(1024, int(args.max_new_tokens / 4)),
251261
"filter_truncated_response": True,
252262
"reward_fn_type": args.reward_type,
263+
"response_format_tags": (
264+
{
265+
"think_start": {"text": "<think>", "num_occur": 1},
266+
"think_end": {"text": "</think>", "num_occur": 1},
267+
"answer_start": {"text": "<answer>", "num_occur": 1},
268+
"answer_end": {"text": "</answer>", "num_occur": 1},
269+
}
270+
if args.reward_type == "think_answer_tags"
271+
else None
272+
),
253273
}
254274
else:
255275
raise ValueError(f"Unsupported algorithm: {args.algo}")

0 commit comments

Comments (0)