Skip to content

Commit de2ad3b

Browse files
Authored by TongLi3701 (Tong Li) and a co-author
fix default eval setting (#6321)
Co-authored-by: Tong Li <[email protected]>
1 parent 32afa7b commit de2ad3b

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
import os
23
import uuid
34
from typing import Any, Dict, Optional
45

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def __init__(
149149
else:
150150
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
151151
else:
152-
raise ValueError("eval_dataset_config is not defined")
152+
print("No eval dataset provided, skip eval")
153153
self.device = get_current_device()
154154

155155
# init backend

applications/ColossalChat/rl_example.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"-ed",
1515
"--eval-dataset",
1616
type=str,
17-
default='{"eval task name":"data_eval.jsonl"}',
17+
default=None,
1818
help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
1919
For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
2020
The key is the task name, and the value is the path to the jsonl file",
@@ -265,10 +265,14 @@
265265
project_name=args.project,
266266
save_interval=args.save_interval,
267267
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
268-
eval_dataset_config={
269-
k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
270-
for k, v in json.loads(args.eval_dataset).items()
271-
},
268+
eval_dataset_config=(
269+
{
270+
k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
271+
for k, v in json.loads(args.eval_dataset).items()
272+
}
273+
if args.eval_dataset
274+
else None
275+
),
272276
eval_interval=args.eval_interval,
273277
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
274278
eval_generation_config=eval_generation_config,

0 commit comments

Comments (0)