4 changes: 2 additions & 2 deletions tests/template/config.yaml
@@ -20,8 +20,8 @@ model:
   max_prompt_tokens: 2048
   max_response_tokens: 2048
 cluster: # 2 for explorer, 2 for trainer
-  node_num: 2
-  gpu_per_node: 2
+  node_num: 1
+  gpu_per_node: 4
 buffer:
   total_epochs: 1
   batch_size: 4
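The new cluster setting keeps the same total of 4 GPUs (2 for the explorer, 2 for the trainer) but places them on a single node instead of two. A minimal sketch of the arithmetic behind the comment, with an illustrative helper name and an assumed 2+2 explorer/trainer split rather than Trinity's actual validation code:

def check_cluster_budget(node_num: int, gpu_per_node: int,
                         explorer_gpus: int = 2, trainer_gpus: int = 2) -> None:
    # Illustrative check: the cluster must offer at least as many GPUs
    # as the explorer and trainer together require.
    total = node_num * gpu_per_node
    required = explorer_gpus + trainer_gpus
    if total < required:
        raise ValueError(f"cluster provides {total} GPUs, {required} required")

# Old setting: 2 nodes x 2 GPUs = 4 GPUs; new setting: 1 node x 4 GPUs = 4 GPUs.
check_cluster_budget(node_num=1, gpu_per_node=4)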
31 changes: 31 additions & 0 deletions tests/trainer/trainer_test.py
@@ -271,6 +271,9 @@ def test_trainer(self):
self.config.algorithm.algorithm_type = "dpo"
self.config.algorithm.policy_loss_fn = "dpo"
self.config.algorithm.policy_loss_fn_args = {}
self.config.buffer.total_epochs = 2
self.config.buffer.total_steps = 4 # step has higher priority than epoch
self.config.synchronizer.sync_interval = 4
# self.config.buffer.batch_size = 32
self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo")
self.config.check_and_update()
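The inline comment says the step budget takes priority over the epoch budget: with total_epochs = 2 and total_steps = 4, the run stops after 4 training steps. A minimal sketch of that precedence rule, assuming (this is not taken from Trinity's source) that the effective budget is resolved roughly as follows:

def resolve_total_steps(total_steps, total_epochs, steps_per_epoch):
    # Illustrative only: an explicit step budget overrides the epoch budget.
    if total_steps is not None and total_steps > 0:
        return total_steps
    return total_epochs * steps_per_epoch

# With total_steps=4 the trainer stops after 4 steps, no matter how many
# steps two epochs would otherwise contain.
assert resolve_total_steps(4, 2, steps_per_epoch=8) == 4
assert resolve_total_steps(None, 2, steps_per_epoch=8) == 16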
@@ -287,6 +290,34 @@ def tearDown(self):
         shutil.rmtree(self.config.checkpoint_job_dir)
 
 
+class TestTrainerSFT(BaseTrainerCase):
+    def test_trainer(self):
+        """Test SFT."""
+        # test both mode
+        self.config.mode = "train"
+        self.config.algorithm.algorithm_type = "sft"
+        self.config.algorithm.policy_loss_fn = "sft"
+        self.config.algorithm.policy_loss_fn_args = {}
+        self.config.algorithm.kl_loss_fn = "none"
+        self.config.algorithm.entropy_loss_fn = "none"
+        self.config.synchronizer.sync_interval = 4
+        self.config.buffer.total_epochs = 2
+        self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config(
+            "sft_for_gsm8k"
+        )
+        self.config.check_and_update()
+        print(self.config)
+        train(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
+
+
 def run_trainer(config: Config) -> None:
     ray.init(namespace=config.ray_namespace)
     train(config)
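The assertion parser.metric_max_step(actor_metrics[0]) == 4 means the SFT run is expected to log actor metrics up to training step 4. Only metric_list and metric_max_step are visible from this test; a sketch of how such a parser could be built on TensorBoard's EventAccumulator, as an illustration of what those two calls need to do rather than Trinity's actual TensorBoardParser:

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

class SimpleTensorBoardParser:
    # Illustrative stand-in for the TensorBoardParser used in the test.
    def __init__(self, log_dir: str):
        self._acc = EventAccumulator(log_dir)
        self._acc.Reload()

    def metric_list(self, prefix: str):
        # All scalar tags whose name starts with the given prefix, e.g. "actor".
        return [t for t in self._acc.Tags()["scalars"] if t.startswith(prefix)]

    def metric_max_step(self, tag: str) -> int:
        # Largest training step at which this scalar was recorded.
        return max(e.step for e in self._acc.Scalars(tag))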
10 changes: 7 additions & 3 deletions trinity/common/config.py
@@ -471,9 +471,13 @@ def _check_buffer(self) -> None:  # noqa: C901
"`buffer.explorer_input.taskset.rollout_args.n` is set to `algorithm.repeat_times`"
f" (={self.algorithm.repeat_times})."
)
self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE
self.buffer.explorer_input.taskset.total_epochs = self.buffer.total_epochs
self.buffer.explorer_input.taskset.total_steps = self.buffer.total_steps
if self.mode == "train":
self.buffer.trainer_input.experience_buffer.total_epochs = self.buffer.total_epochs
self.buffer.trainer_input.experience_buffer.total_steps = self.buffer.total_steps
else:
self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE
self.buffer.explorer_input.taskset.total_epochs = self.buffer.total_epochs
self.buffer.explorer_input.taskset.total_steps = self.buffer.total_steps
if self.buffer.explorer_input.taskset.default_workflow_type is None:
self.buffer.explorer_input.taskset.default_workflow_type = (
self.buffer.explorer_input.default_workflow_type
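In effect, the epoch/step budget now flows to the trainer's experience buffer when mode is "train", and to the explorer's taskset in the other modes, where the taskset also gets its EXPLORE task type. A condensed, standalone restatement of that branch (field names follow the diff; the surrounding Config machinery is omitted for illustration):

def propagate_budget(config) -> None:
    # Simplified restatement of the new branch in Config._check_buffer.
    if config.mode == "train":
        # Pure training (e.g. SFT/DPO): the trainer reads experiences directly,
        # so the epoch/step budget belongs to its experience buffer.
        buf = config.buffer.trainer_input.experience_buffer
        buf.total_epochs = config.buffer.total_epochs
        buf.total_steps = config.buffer.total_steps
    else:
        # Modes with an explorer: the explorer's taskset drives progress.
        taskset = config.buffer.explorer_input.taskset
        taskset.task_type = TaskType.EXPLORE
        taskset.total_epochs = config.buffer.total_epochs
        taskset.total_steps = config.buffer.total_steps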