
Commit d35ed00

fix some bugs for eval (#2)
1 parent d404dfa commit d35ed00

File tree

11 files changed (+80, -55 lines)

scripts/config/alfworld.yaml

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,8 @@ data:
   batch_size: 4
   dataset_path: 'scripts/data_prepare/alfworld_data'
   default_workflow_type: 'alfworld_workflow'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   format_config:
     prompt_key: 'game_file'
 model:

scripts/config/countdown.yaml

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,8 @@ data:
   batch_size: 96
   dataset_path: 'countdown_dataset/oneshot-split'
   default_workflow_type: 'math_workflow'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   default_reward_fn_type: 'countdown_reward'
   format_config:
     prompt_key: 'question'

scripts/config/gsm8k.yaml

Lines changed: 3 additions & 2 deletions
@@ -1,8 +1,8 @@
 data:
   # basic info
   dataset_path: '/PATH/TO/DATASET/'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   format_config:
     prompt_key: 'question'
     response_key: 'answer'
@@ -70,6 +70,7 @@ trainer:
   algorithm_type: ppo
   trainer_config_path: 'scripts/config/train_gsm8k.yaml'
   sft_warmup_iteration: 0  # Set to integer to enable sft warmup
+  eval_interval: 50
 monitor:
   cache_root_dir: ""
   project: "Trinity-RFT-gsm8k"

scripts/config/webshop.yaml

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,8 @@ data:
   batch_size: 4
   dataset_path: 'scripts/data_prepare/webshop_data'
   default_workflow_type: 'webshop_workflow'
-  dataset_config:
-    split: 'train'
+  train_split: 'train'
+  eval_split: ''
   format_config:
     prompt_key: 'task_id'
 model:

tests/common/tmp/template_config.yaml

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,8 @@ data:
   dataset_path: ''
   total_epoch: 1
   batch_size: 1
-  split: train
+  train_split: 'train'
+  eval_split: ''
   default_workflow_type: ''
   default_reward_fn_type: ''
   dataset_config: {}
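
Every config makes the same rename: the nested `dataset_config.split` key becomes flat `train_split` and `eval_split` keys, with an empty `eval_split` meaning no evaluation data. A minimal sketch of how a loader might consume the new keys; the `load_taskset` helper and the empty-string convention are illustrative assumptions, not code from this commit:

    from datasets import load_dataset  # Hugging Face datasets

    def load_taskset(dataset_path: str, train_split: str, eval_split: str):
        # Hypothetical loader for the flattened split keys.
        train_set = load_dataset(dataset_path, split=train_split)
        # An empty eval_split is read as "no evaluation set configured".
        eval_set = load_dataset(dataset_path, split=eval_split) if eval_split else None
        return train_set, eval_set

    # e.g. train_split='train', eval_split='' as in the configs above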

trinity/cli/launcher.py

Lines changed: 17 additions & 14 deletions
@@ -19,8 +19,7 @@ def explore(config: Config) -> None:
     try:
         ray.get(explorer.prepare.remote())
         ray.get(explorer.sync_weight.remote())
-        ref, _ = ray.wait([explorer.explore.remote()])
-        ray.get(ref)
+        ray.get(explorer.explore.remote())
         logger.info("Explore finished.")
     except Exception as e:
         logger.error(f"Explore failed: {e}")
@@ -34,8 +33,7 @@ def train(config: Config) -> None:
     trainer = Trainer.remote(config)
     try:
         ray.get(trainer.prepare.remote())
-        ref, _ = ray.wait([trainer.train.remote(algo_type)])
-        ray.get(ref)
+        ray.get(trainer.train.remote(algo_type))
         logger.info("Train finished.")
     except Exception as e:
         logger.error(f"Train failed {e}.")
@@ -67,20 +65,21 @@ def both(config: Config) -> None:

     if config.trainer.sft_warmup_iteration > 0:
         for step in range(config.trainer.sft_warmup_iteration):
-            ray.get([trainer.train_step.remote(AlgorithmType.SFT)])
+            ray.get(trainer.train_step.remote(AlgorithmType.SFT))
             logger.info(f"SFT warmup step {step} finished.")
         ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])

     algo_type = config.trainer.algorithm_type
-    global_iter_num = 0
     while True:
         try:
-            explore_continue = explorer.explore_step.remote()
-            train_continue = trainer.train_step.remote(algo_type)
-            if not ray.get(explore_continue):
+            ref_explore = explorer.explore_step.remote()
+            ref_train = trainer.train_step.remote(algo_type)
+            explore_continue, _ = ray.get(ref_explore)
+            train_continue, train_iter_num = ray.get(ref_train)
+            if not explore_continue:
                 logger.info("Explorer finished, stopping...")
                 break
-            if not ray.get(train_continue):
+            if not train_continue:
                 logger.info("Trainer finished, stopping...")
                 break
             ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
@@ -89,10 +88,14 @@ def both(config: Config) -> None:
             logger.error(e)
             logger.error("Training stopped due to exception.")
             raise e
-        global_iter_num += 1
-        if global_iter_num % config.trainer.eval_interval == 0:
-            ray.wait([explorer.eval.remote()])
-            logger.info("Eval step finished.")
+        if (train_iter_num - 1) % config.trainer.eval_interval == 0:
+            try:
+                ray.get(explorer.eval.remote(train_iter_num))
+                logger.info("Evaluation finished.")
+            except Exception as e:
+                logger.error(e)
+                logger.error("Evaluation failed.")
+                raise e


 def main() -> None:
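
The eval trigger is now keyed to the trainer's reported iteration instead of a separately counted `global_iter_num`. Assuming a 1-based counter that advances one iteration per `train_step` (i.e. `sync_iteration_interval == 1`), `(train_iter_num - 1) % eval_interval == 0` fires right after the first iteration and then every `eval_interval` iterations; a quick check of the cadence:

    eval_interval = 50
    # Iterations at which the launcher would call explorer.eval, under the
    # 1-based, one-iteration-per-step assumption above (values illustrative):
    fires = [n for n in range(1, 152) if (n - 1) % eval_interval == 0]
    print(fires)  # [1, 51, 101, 151]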

trinity/common/config.py

Lines changed: 10 additions & 0 deletions
@@ -306,6 +306,16 @@ def check_and_update(self) -> None:
         self.synchronizer.backend = self.explorer.backend
         if self.synchronizer.sync_method == "online" and self.mode != "both":
             raise ValueError("Online synchronization is only supported in both mode")
+
+        # check eval_interval
+        if self.trainer.eval_interval % self.synchronizer.sync_iteration_interval != 0:
+            self.trainer.eval_interval = (
+                self.trainer.eval_interval // self.synchronizer.sync_iteration_interval
+            ) * self.synchronizer.sync_iteration_interval
+            print(
+                f"Warning: eval_interval is not a multiple of sync_iteration_interval; adjusted to the nearest integer={self.trainer.eval_interval}."
+            )
+
         # check monitor
         if not self.monitor.cache_root_dir:
             # create a cache dir in <checkpoint_path>/.cache
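
The check floors `eval_interval` down to the nearest multiple of `sync_iteration_interval`, so evaluation always coincides with a weight-sync boundary. The arithmetic, worked through:

    sync_iteration_interval = 4
    eval_interval = 50
    # 50 is not a multiple of 4, so floor it: 50 // 4 = 12, and 12 * 4 = 48.
    # (An eval_interval smaller than sync_iteration_interval would floor to 0.)
    adjusted = (eval_interval // sync_iteration_interval) * sync_iteration_interval
    print(adjusted)  # 48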

trinity/common/workflows/workflow.py

Lines changed: 5 additions & 4 deletions
@@ -58,7 +58,8 @@ def run(self) -> List[Experience]:
         else:
             messages = [{"role": "user", "content": self.task_desc}]
         logger.debug("start chat")
-        responses = self.model.chat(messages, n=self.repeat_times)
+        n = 1 if self.is_eval else self.repeat_times
+        responses = self.model.chat(messages, n=n)
         for response in responses:
             reward = self.reward_fn(  # type: ignore [misc]
                 response=response.response_text,  # type: ignore [arg-type]
@@ -69,9 +70,9 @@ def run(self) -> List[Experience]:
                 f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
             )
             if isinstance(reward, dict):
-                if response.info is None:
-                    response.info = {}
-                response.info.update(reward)
+                if response.metrics is None:
+                    response.metrics = {}
+                response.metrics.update(reward)
                 reward = sum(reward.values())
             response.reward = reward
         return responses
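
Two behavioral changes here: evaluation requests a single response per task instead of `repeat_times` samples, and dict-valued rewards are now folded into `response.metrics` (rather than `response.info`) before being summed into the scalar reward. A self-contained sketch of the folding step; `fold_reward` is a hypothetical helper mirroring the diff, not a function in this repo:

    def fold_reward(metrics, reward):
        # Dict rewards carry per-component scores; keep them as metrics
        # and collapse them into a single scalar for training.
        if isinstance(reward, dict):
            metrics = {**(metrics or {}), **reward}
            reward = sum(reward.values())
        return metrics, reward

    metrics, reward = fold_reward(None, {"format": 1, "accuracy": 2})
    assert metrics == {"format": 1, "accuracy": 2} and reward == 3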

trinity/explorer/explorer.py

Lines changed: 13 additions & 11 deletions
@@ -3,7 +3,7 @@
 import os
 import time
 from collections import defaultdict
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import ray
 import torch
@@ -149,16 +149,20 @@ def get_weight(self, name: str) -> torch.Tensor:

     def explore(self) -> None:
         """Explore the entire dataset."""
-        while self.explore_step():
+        explore_status, _ = self.explore_step()
+        while explore_status:
             self.sync_weight()
         self.logger.info("Explorer finished.")

-    def explore_step(self) -> bool:
+    def explore_step(self) -> Tuple[bool, int]:
         """Explore for one step.

         Different from `explore()` which consumes all tasks in the task set,
         `explore_step()` only consume `sync_iteration_interval * batch_size`
         number of tasks.
+        Returns:
+            explore_status: whether there are more tasks to explore.
+            explore_iter_num: the number of explore iterations
         """
         if self.task_iter is None:
             self.task_iter = iter(self.taskset)
@@ -175,7 +179,7 @@ def explore_step(self) -> bool:
             self.runner_pool.run_tasks(tasks)
         except StopIteration:
             self.logger.warning("No more tasks in the task set. Stop exploring.")
-            return False
+            return False, self.iteration

         # wait for all tasks of this step to finish
         while self.runner_pool.has_next():
@@ -190,7 +194,7 @@ def explore_step(self) -> bool:
                     self.runner_pool.run_tasks(next(self.task_iter))  # type: ignore
                 except StopIteration:
                     self.logger.warning("No more tasks in the task set. Stop exploring.")
-                    return False
+                    return False, self.iteration
             else:
                 for metric_name, metric_value in status.metric.items():
                     all_metrics[metric_name].append(metric_value)
@@ -208,11 +212,11 @@ def explore_step(self) -> bool:
         )

         self.logger.info("Explore step finished.")
-        return True
+        return True, self.iteration

-    def eval(self) -> bool:
+    def eval(self, step) -> bool:
         """Evaluation on all evaluation data samples."""
-        self.logger.info("\n\nEvaluation started.\n\n")
+        self.logger.info("Evaluation started.")
         st = time.time()
         all_metrics = defaultdict(list)

@@ -231,11 +235,9 @@ def eval(self) -> bool:
             for metric_name, metric_value in status.metric.items():
                 all_metrics[metric_name].append(metric_value)

-        self.logger.info("Evaluation finished.")
-
         log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="eval")  # type: ignore
         log_metrics["eval/total_time"] = time.time() - st
-        self.monitor.log(log_metrics, step=self.iteration)  # type: ignore
+        self.monitor.log(log_metrics, step=step)  # type: ignore
         return True

     def sync_weight(self) -> None:
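
Passing the trainer's iteration into `eval` means eval metrics land on the same x-axis as training metrics in the monitor, instead of on the explorer's own `self.iteration`. A toy illustration with a hypothetical wandb-style logger stub (the values are made up for the example):

    class MonitorStub:
        # Records (step, metrics) pairs like a wandb-style logger would.
        def __init__(self):
            self.history = []

        def log(self, metrics, step):
            self.history.append((step, metrics))

    monitor = MonitorStub()
    train_iter_num = 51  # reported by trainer.train_step in the launcher
    monitor.log({"eval/total_time": 12.3}, step=train_iter_num)
    assert monitor.history[-1][0] == 51  # aligned with the training step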

trinity/trainer/trainer.py

Lines changed: 16 additions & 10 deletions
@@ -7,6 +7,7 @@
 Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
 """
 from abc import ABC, abstractmethod
+from typing import Tuple

 import ray

@@ -45,18 +46,23 @@ def prepare(self) -> None:

     def train(self, algo_type: AlgorithmType = AlgorithmType.PPO):
         """Train the model."""
         while True:
-            if not self.train_iteration(algo_type):
+            train_status, _ = self.train_iteration(algo_type)
+            if not train_status:
                 break

-    def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> bool:
-        """Train one step. Each step contains `sync_iteration_interval` iteration."""
+    def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
+        """Train one step. Each step contains `sync_iteration_interval` iteration.
+        Returns:
+            train_status: Whether to continue training.
+            train_iter_num: The number of training iterations"""
         for _ in range(self.config.synchronizer.sync_iteration_interval):
-            if not self.train_iteration(algo_type):
-                return False
+            train_status, train_iter_num = self.train_iteration(algo_type)
+            if not train_status:
+                return False, train_iter_num
         self.logger.info("Trainer finished.")
-        return True
+        return True, train_iter_num

-    def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> bool:
+    def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
         """Train one iteration.

         Args:
@@ -108,15 +114,15 @@ def prepare(self) -> None:
         """Do some preparation before training started."""

     @abstractmethod
-    def train_rft_iteration(self, experiences) -> bool:
+    def train_rft_iteration(self, experiences) -> Tuple[bool, int]:
         """Train on the RFT data."""

     @abstractmethod
-    def train_sft_iteration(self, experiences) -> bool:
+    def train_sft_iteration(self, experiences) -> Tuple[bool, int]:
         """Train on the SFT data."""

     @abstractmethod
-    def train_dpo_iteration(self, experiences) -> bool:
+    def train_dpo_iteration(self, experiences) -> Tuple[bool, int]:
         """Train on the DPO data."""

     @abstractmethod
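
All the `train_*_iteration` hooks now share a `Tuple[bool, int]` contract: a continue flag plus the iteration that just ran, which `train_step` threads back to the launcher. A self-contained sketch of the pattern with a toy trainer (a stand-in for illustration, not the real backend implementation):

    from typing import Tuple

    class ToyTrainer:
        def __init__(self, total_iters: int, sync_iteration_interval: int):
            self.total_iters = total_iters
            self.sync_iteration_interval = sync_iteration_interval
            self.iteration = 0

        def train_iteration(self) -> Tuple[bool, int]:
            # Returns (continue_training, iteration_just_finished).
            self.iteration += 1
            return self.iteration < self.total_iters, self.iteration

        def train_step(self) -> Tuple[bool, int]:
            # One step = sync_iteration_interval iterations, as in the diff.
            for _ in range(self.sync_iteration_interval):
                train_status, train_iter_num = self.train_iteration()
                if not train_status:
                    return False, train_iter_num
            return True, train_iter_num

    t = ToyTrainer(total_iters=10, sync_iteration_interval=4)
    print(t.train_step())  # (True, 4)
    print(t.train_step())  # (True, 8)
    print(t.train_step())  # (False, 10)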
