diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index a697539141..1e50991c6b 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -271,6 +271,9 @@ def test_trainer(self): self.config.algorithm.algorithm_type = "dpo" self.config.algorithm.policy_loss_fn = "dpo" self.config.algorithm.policy_loss_fn_args = {} + self.config.buffer.total_epochs = 2 + self.config.buffer.total_steps = 4 # step has higher priority than epoch + self.config.synchronizer.sync_interval = 4 # self.config.buffer.batch_size = 32 self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo") self.config.check_and_update() @@ -287,6 +290,33 @@ def tearDown(self): shutil.rmtree(self.config.checkpoint_job_dir) +class TestTrainerSFT(BaseTrainerCase): + def test_trainer(self): + """Test SFT.""" + # test both mode + self.config.mode = "train" + self.config.algorithm.algorithm_type = "sft" + self.config.algorithm.policy_loss_fn = "sft" + self.config.algorithm.policy_loss_fn_args = {} + self.config.algorithm.kl_loss_fn = "none" + self.config.algorithm.entropy_loss_fn = "none" + self.config.synchronizer.sync_interval = 4 + self.config.buffer.total_epochs = 2 + self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config( + "sft_for_gsm8k" + ) + self.config.check_and_update() + train(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) + + def run_trainer(config: Config) -> None: ray.init(namespace=config.ray_namespace) train(config) diff --git a/trinity/common/config.py b/trinity/common/config.py index 13b37ac26d..ab7d8a08cd 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -471,9 +471,16 @@ def _check_buffer(self) -> None: # noqa: C901 "`buffer.explorer_input.taskset.rollout_args.n` is set to `algorithm.repeat_times`" f" (={self.algorithm.repeat_times})." ) - self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE - self.buffer.explorer_input.taskset.total_epochs = self.buffer.total_epochs - self.buffer.explorer_input.taskset.total_steps = self.buffer.total_steps + if self.mode == "train": + assert ( + self.buffer.trainer_input.experience_buffer is not None + ), "`buffer.trainer_input.experience_buffer` is required when `mode` is `train`." + self.buffer.trainer_input.experience_buffer.total_epochs = self.buffer.total_epochs + self.buffer.trainer_input.experience_buffer.total_steps = self.buffer.total_steps + else: + self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE + self.buffer.explorer_input.taskset.total_epochs = self.buffer.total_epochs + self.buffer.explorer_input.taskset.total_steps = self.buffer.total_steps if self.buffer.explorer_input.taskset.default_workflow_type is None: self.buffer.explorer_input.taskset.default_workflow_type = ( self.buffer.explorer_input.default_workflow_type