Skip to content

Commit 2783593

Browse files
authored
Add unittest for mix (agentscope-ai#200)
1 parent 21253c4 commit 2783593

File tree

3 files changed

+61
-3
lines changed

3 files changed

+61
-3
lines changed

examples/mix_chord/mix_chord.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ algorithm:
2121
ppo_micro_batch_size_per_gpu: 4
2222
ngpus_trainer: 4
2323
train_batch_size_expert: 64
24-
train_batch_size_usual: 256 # (40 batchsize * (1 - 0.2 expert_data_ratio)) * 8 repeat times
24+
train_batch_size_usual: 256 # 32 batchsize * 8 repeat times
2525
model:
2626
model_path: /PATH/TO/MODEL/
2727
max_response_tokens: 10240
@@ -31,7 +31,8 @@ cluster:
3131
gpu_per_node: 8
3232
buffer:
3333
total_epochs: 4
34-
batch_size: 40
34+
batch_size: 32
35+
train_batch_size: 320
3536
max_retry_times: 3
3637
max_retry_interval: 1
3738
explorer_input:

examples/mix_math/mix_math.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ cluster:
2525
gpu_per_node: 8
2626
buffer:
2727
total_epochs: 10
28-
batch_size: 40
28+
batch_size: 32
29+
train_batch_size: 320
2930
max_retry_times: 3
3031
max_retry_interval: 1
3132
explorer_input:

tests/trainer/trainer_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,59 @@ def test_fully_async_mode(self, name, use_priority_queue):
456456
def tearDown(self):
    """Remove the ``unittest`` directory located under the checkpoint path."""
    unittest_dir = os.path.join(get_checkpoint_path(), "unittest")
    shutil.rmtree(unittest_dir)
459+
460+
461+
class TestTrainerMIX(BaseTrainerCase):
    """Trainer test case for the ``mix`` algorithm, sample strategy and policy loss."""

    def test_trainer(self):
        """Test MIX algorithm.

        Configures the ``mix`` algorithm, runs it through ``both`` and then
        verifies the TensorBoard metrics (rollout, actor, response length)
        and that a checkpoint exists for the last step.
        """
        # gsm8k has 16 tasks, sft_for_gsm8k has 8 tasks
        # total 4 steps, each step: read 4 tasks from gsm8k, 16 tasks from sft_for_gsm8k
        self.config.algorithm.algorithm_type = "mix"
        self.config.algorithm.repeat_times = 4
        self.config.algorithm.sample_strategy = "mix"
        self.config.algorithm.sample_strategy_args = {"expert_data_ratio": 0.5}  # rft=4*4 : sft=16
        self.config.algorithm.policy_loss_fn = "mix"
        self.config.buffer.batch_size = 4
        self.config.buffer.train_batch_size = 32
        self.config.buffer.total_epochs = 1
        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
        self.config.synchronizer.sync_interval = 1
        self.config.trainer.save_interval = 1
        self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
            "sft_for_gsm8k"
        )
        # Per-dataset epoch override; the test checks that this setting is honoured.
        self.config.buffer.trainer_input.sft_warmup_dataset.total_epochs = 8  # test this works
        self.config.check_and_update()
        # These two values are set after check_and_update(), as in the original test.
        # NOTE(review): presumably so validation does not overwrite them - confirm.
        self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20
        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
        both(self.config)
        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))

        # test rollout metrics
        rollout_metrics = parser.metric_list("rollout")
        self.assertGreater(len(rollout_metrics), 0)
        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
        self.assertEqual(
            parser.metric_values("rollout/experience_count")[1], 16
        )  # 16 rft experiences

        # test actor metrics
        actor_metrics = parser.metric_list("actor")
        self.assertGreater(len(actor_metrics), 0)
        expert_metrics = parser.metric_list("actor/expert/")
        # Fail with a clear assertion instead of an IndexError when no expert metric exists.
        self.assertGreater(len(expert_metrics), 0)
        self.assertEqual(parser.metric_max_step(expert_metrics[0]), 4)  # SFT
        usual_metrics = parser.metric_list("actor/usual/")
        # Same guard for the usual (RFT) metrics.
        self.assertGreater(len(usual_metrics), 0)
        self.assertEqual(parser.metric_max_step(usual_metrics[0]), 4)  # RFT
        response_metrics = parser.metric_list("response_length")
        self.assertGreater(len(response_metrics), 0)
        self.assertEqual(parser.metric_min_step(response_metrics[0]), 1)
        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)

        # test save checkpoint at last step
        checkpoint_dir, step_num = get_checkpoint_dir_with_step_num(
            checkpoint_root_path=self.config.checkpoint_job_dir,
            trainer_type="verl",
        )
        self.assertEqual(step_num, 4)
        self.assertGreater(len(os.listdir(os.path.join(checkpoint_dir, "actor"))), 0)

    def tearDown(self):
        """Remove the checkpoint job directory created by the test."""
        # ignore_errors: the directory may not exist if the test failed early;
        # a cleanup error must not hide the original test failure.
        shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)

0 commit comments

Comments
 (0)