
Commit 382307a

Tong Li (TongLi3701) authored and committed
fix default eval setting (#6321)
Co-authored-by: Tong Li <[email protected]>
1 parent: 2a39d3a · commit: 382307a


2 files changed: +19 −6 lines changed


applications/ColossalChat/coati/distributed/producer.py

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ def __init__(
                 raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
             self.response_format_tags = response_format_tags
         else:
-            raise ValueError("eval_dataset_config is not defined")
+            print("No eval dataset provided, skip eval")
         self.device = get_current_device()
 
         # init backend
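
For context, a minimal sketch of the behaviour this hunk changes: when no eval_dataset_config is supplied, the producer now prints a notice and skips evaluation instead of raising. Everything except eval_dataset_config and the printed message is a simplified stand-in, not the actual producer.py code.

# Simplified stand-in for the producer's __init__ logic (hypothetical class;
# only eval_dataset_config and the printed message come from the diff).
class SimpleProducer:
    def __init__(self, eval_dataset_config=None):
        if eval_dataset_config is not None:
            # The real producer builds its evaluation datasets and reward
            # functions here.
            self.eval_dataset_config = eval_dataset_config
        else:
            # Previously this branch raised ValueError("eval_dataset_config is not defined");
            # after this commit, evaluation is simply skipped.
            print("No eval dataset provided, skip eval")
            self.eval_dataset_config = None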

applications/ColossalChat/rl_example.py

Lines changed: 18 additions & 5 deletions
@@ -10,7 +10,16 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
 parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
-parser.add_argument("-p", "--project", type=str, default="GRPO-V3", help="Project name.")
+parser.add_argument(
+    "-ed",
+    "--eval-dataset",
+    type=str,
+    default=None,
+    help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
+        For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
+        The key is the task name, and the value is the path to the jsonl file",
+)
+parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
 parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
 
 # Distributed training parameters
@@ -301,10 +310,14 @@
         project_name=args.project,
         save_interval=args.save_interval,
         save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
-        eval_dataset_config={
-            k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
-            for k, v in json.loads(args.eval_dataset).items()
-        },
+        eval_dataset_config=(
+            {
+                k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
+                for k, v in json.loads(args.eval_dataset).items()
+            }
+            if args.eval_dataset
+            else None
+        ),
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
         eval_generation_config=eval_generation_config,
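
A hedged usage sketch of the new -ed/--eval-dataset flag, showing how the JSON string maps onto eval_dataset_config; the dataset paths, token limit, and system prompt below are placeholders, not values taken from the repository.

import json

# Hypothetical invocation (placeholder paths):
#   python rl_example.py -m Qwen/Qwen2.5-7B -d data.jsonl \
#       -ed '{"task1": "data_eval_task1.jsonl", "task2": "data_eval_task2.jsonl"}'
eval_dataset = '{"task1": "data_eval_task1.jsonl", "task2": "data_eval_task2.jsonl"}'
max_prompt_tokens = 512  # placeholder; rl_example.py reads this from its own CLI flag
system_prompt = None     # placeholder

# Mirrors the conditional added by the commit: a per-task dict when the flag is
# set, otherwise None so the producer skips evaluation.
eval_dataset_config = (
    {
        k: {"path": v, "max_length": max_prompt_tokens, "system_prompt": system_prompt}
        for k, v in json.loads(eval_dataset).items()
    }
    if eval_dataset
    else None
)
print(eval_dataset_config)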
