Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scripts/config/alfworld.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ data:
batch_size: 4
dataset_path: 'scripts/data_prepare/alfworld_data'
default_workflow_type: 'alfworld_workflow'
dataset_config:
split: 'train'
train_split: 'train'
eval_split: ''
format_config:
prompt_key: 'game_file'
model:
Expand Down
4 changes: 2 additions & 2 deletions scripts/config/countdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ data:
batch_size: 96
dataset_path: 'countdown_dataset/oneshot-split'
default_workflow_type: 'math_workflow'
dataset_config:
split: 'train'
train_split: 'train'
eval_split: ''
default_reward_fn_type: 'countdown_reward'
format_config:
prompt_key: 'question'
Expand Down
5 changes: 3 additions & 2 deletions scripts/config/gsm8k.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
data:
# basic info
dataset_path: '/PATH/TO/DATASET/'
dataset_config:
split: 'train'
train_split: 'train'
eval_split: ''
format_config:
prompt_key: 'question'
response_key: 'answer'
Expand Down Expand Up @@ -70,6 +70,7 @@ trainer:
algorithm_type: ppo
trainer_config_path: 'scripts/config/train_gsm8k.yaml'
sft_warmup_iteration: 0 # Set to integer to enable sft warmup
eval_interval: 10
monitor:
cache_root_dir: ""
project: "Trinity-RFT-gsm8k"
Expand Down
4 changes: 2 additions & 2 deletions scripts/config/webshop.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ data:
batch_size: 4
dataset_path: 'scripts/data_prepare/webshop_data'
default_workflow_type: 'webshop_workflow'
dataset_config:
split: 'train'
train_split: 'train'
eval_split: ''
format_config:
prompt_key: 'task_id'
model:
Expand Down
3 changes: 2 additions & 1 deletion tests/common/tmp/template_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ data:
dataset_path: ''
total_epoch: 1
batch_size: 1
split: train
train_split: 'train'
eval_split: ''
default_workflow_type: ''
default_reward_fn_type: ''
dataset_config: {}
Expand Down
15 changes: 10 additions & 5 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def both(config: Config) -> None:
ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])

algo_type = config.trainer.algorithm_type
global_iter_num = 0
while True:
try:
explore_continue = explorer.explore_step.remote()
Expand All @@ -89,10 +88,16 @@ def both(config: Config) -> None:
logger.error(e)
logger.error("Training stopped due to exception.")
raise e
global_iter_num += 1
if global_iter_num % config.trainer.eval_interval == 0:
ray.wait([explorer.eval.remote()])
logger.info("Eval step finished.")
train_step_num = ray.get(trainer.get_current_step.remote())
if (train_step_num - 1) % config.trainer.eval_interval == 0:
ref, _ = ray.wait([explorer.eval.remote(step=train_step_num)])
try:
ray.get(ref)
logger.info("Evaluation finished.")
except Exception as e:
logger.error(e)
logger.error("Evaluation failed.")
raise e


def main() -> None:
Expand Down
10 changes: 10 additions & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,16 @@ def check_and_update(self) -> None:
self.synchronizer.backend = self.explorer.backend
if self.synchronizer.sync_method == "online" and self.mode != "both":
raise ValueError("Online synchronization is only supported in both mode")

# check eval_interval
if self.trainer.eval_interval % self.synchronizer.sync_iteration_interval != 0:
self.trainer.eval_interval = (
self.trainer.eval_interval // self.synchronizer.sync_iteration_interval
) * self.synchronizer.sync_iteration_interval
print(
f"Warning: eval_interval is not a multiple of sync_iteration_interval; adjusted to the nearest integer={self.trainer.eval_interval}."
)

# check monitor
if not self.monitor.cache_root_dir:
# create a cache dir in <checkpoint_path>/.cache
Expand Down
9 changes: 5 additions & 4 deletions trinity/common/workflows/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def run(self) -> List[Experience]:
else:
messages = [{"role": "user", "content": self.task_desc}]
logger.debug("start chat")
responses = self.model.chat(messages, n=self.repeat_times)
n = 1 if self.is_eval else self.repeat_times
responses = self.model.chat(messages, n=n)
for response in responses:
reward = self.reward_fn( # type: ignore [misc]
response=response.response_text, # type: ignore [arg-type]
Expand All @@ -69,9 +70,9 @@ def run(self) -> List[Experience]:
f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
)
if isinstance(reward, dict):
if response.info is None:
response.info = {}
response.info.update(reward)
if response.metrics is None:
response.metrics = {}
response.metrics.update(reward)
reward = sum(reward.values())
response.reward = reward
return responses
Expand Down
8 changes: 3 additions & 5 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,9 @@ def explore_step(self) -> bool:
self.logger.info("Explore step finished.")
return True

def eval(self) -> bool:
def eval(self, step) -> bool:
"""Evaluation on all evaluation data samples."""
self.logger.info("\n\nEvaluation started.\n\n")
self.logger.info("Evaluation started.")
st = time.time()
all_metrics = defaultdict(list)

Expand All @@ -231,11 +231,9 @@ def eval(self) -> bool:
for metric_name, metric_value in status.metric.items():
all_metrics[metric_name].append(metric_value)

self.logger.info("Evaluation finished.")

log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="eval") # type: ignore
log_metrics["eval/total_time"] = time.time() - st
self.monitor.log(log_metrics, step=self.iteration) # type: ignore
self.monitor.log(log_metrics, step=step) # type: ignore
return True

def sync_weight(self) -> None:
Expand Down
3 changes: 3 additions & 0 deletions trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def sync_weight(self) -> None:
if self.config.synchronizer.sync_method == "online":
self.engine.sync_weight()

def get_current_step(self) -> int:
    """Return the current global training step reported by the wrapped engine."""
    current_step = self.engine.get_current_step()
    return current_step


class TrainEngineWrapper(ABC):
"""A wrapper class to wrap various training engines."""
Expand Down
3 changes: 3 additions & 0 deletions trinity/trainer/verl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,3 +559,6 @@ def sft_to_rft(self) -> None:

def shutdown(self) -> None:
    """Shut down the trainer.

    No-op: this backend performs no explicit cleanup here (presumably
    resources are released elsewhere — TODO confirm against the engine
    lifecycle).
    """
    pass

def get_current_step(self) -> int:
    """Expose the trainer's global step counter.

    Returns:
        The value of ``self.global_steps`` at the time of the call.
    """
    return self.global_steps