diff --git a/pyproject.toml b/pyproject.toml
index dcf86f8349..022c9a8ffe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
 dependencies = [
     "verl==0.3.0.post1",
     "ray[default]>=2.45.0",
-    "vllm>=0.8.5",
+    "vllm==0.8.5.post1",
     "tensordict==0.6.2",
     "wandb",
     "omegaconf",
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index c83d938c66..3a767df243 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -18,8 +18,8 @@ model:
   max_prompt_tokens: 2048
   max_response_tokens: 2048
 cluster: # 2 for explorer, 2 for trainer
-  node_num: 1
-  gpu_per_node: 4
+  node_num: 2
+  gpu_per_node: 2
 buffer:
   total_epochs: 1
   batch_size: 4
diff --git a/tests/template/data/sft_for_gsm8k/sft.jsonl b/tests/template/data/sft_for_gsm8k/sft.jsonl
new file mode 100644
index 0000000000..a8d6972103
--- /dev/null
+++ b/tests/template/data/sft_for_gsm8k/sft.jsonl
@@ -0,0 +1,32 @@
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
diff --git a/tests/tools.py b/tests/tools.py
index 2e34438d66..3111839a37 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -13,6 +13,7 @@
     StorageConfig,
     load_config,
 )
+from trinity.common.constants import PromptType
 
 
 def get_template_config() -> Config:
@@ -59,6 +60,47 @@ def get_unittest_dataset_config(
             default_workflow_type="math_workflow",
             default_reward_fn_type="countdown_reward",
         )
+    elif dataset_name == "gsm8k":
+        return StorageConfig(
+            name=dataset_name,
+            path="openai/gsm8k",
+            split=split,
+            subset_name="main",
+            format=FormatConfig(
+                prompt_key="question",
+                response_key="answer",
+            ),
+            rollout_args=GenerationConfig(
+                n=1,
+                temperature=1.0,
+                logprobs=0,
+            ),
+            default_workflow_type="math_workflow",
+            default_reward_fn_type="math_reward",
+        )
+    elif dataset_name == "sft_for_gsm8k":
+        return StorageConfig(
+            name=dataset_name,
+            path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"),
+            split="train",
+            format=FormatConfig(
+                prompt_type=PromptType.PLAINTEXT,
+                prompt_key="prompt",
+                response_key="response",
+            ),
+        )
+    elif dataset_name == "dpo":
+        return StorageConfig(
+            name=dataset_name,
+            path="HumanLLMs/Human-Like-DPO-Dataset",
+            split="train",
+            format=FormatConfig(
+                prompt_type=PromptType.PLAINTEXT,
+                prompt_key="prompt",
+                chosen_key="chosen",
+                rejected_key="rejected",
+            ),
+        )
     else:
         raise ValueError(f"Unknown dataset name: {dataset_name}")
 
@@ -104,6 +146,11 @@ def metric_steps(self, metric_name: str) -> List[int]:
             raise ValueError(f"Metric '{metric_name}' does not exist.")
         return list(self._metrics[metric_name].keys())
 
+    def metric_values(self, metric_name: str) -> List:
+        if not self.metric_exist(metric_name):
+            raise ValueError(f"Metric '{metric_name}' does not exist.")
+        return list(self._metrics[metric_name].values())
+
     def metric_list(self, metric_prefix: str) -> List[str]:
         return [name for name in self._metrics if name.startswith(metric_prefix)]
 
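Note: `tests/tools.py` gains a `metric_values` accessor alongside the existing `metric_steps`; the commented-out reward assertions in `TestTrainerGSM8K` below are its intended consumer. A minimal sketch of that usage, assuming `TensorBoardParser` is importable from `tests.tools` and using an illustrative log path:

```python
import os

from tests.tools import TensorBoardParser  # import location assumed

parser = TensorBoardParser(os.path.join("checkpoints", "monitor", "tensorboard"))
if parser.metric_exist("critic/rewards/mean"):
    steps = parser.metric_steps("critic/rewards/mean")     # e.g. [1, 2, 3, 4]
    rewards = parser.metric_values("critic/rewards/mean")  # one value per step
    assert len(rewards) == len(steps)
```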
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index ac73e46c8d..55f63ae856 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -14,8 +14,8 @@
     get_template_config,
     get_unittest_dataset_config,
 )
-from trinity.cli.launcher import bench, both
-from trinity.common.constants import MonitorType, SyncMethod
+from trinity.cli.launcher import bench, both, train
+from trinity.common.constants import AlgorithmType, MonitorType, SyncMethod
 
 
 class BaseTrainerCase(RayUnittestBase):
@@ -109,3 +109,107 @@ def test_trainer(self):
     def tearDown(self):
         # remove dir only when the test passed
         shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerGSM8K(BaseTrainerCase):
+    def test_trainer(self):
+        """Test GSM8K."""
+        # test both mode
+        self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+        self.config.algorithm.repeat_times = 4
+        # self.config.algorithm.repeat_times = 8  # TODO: used for real testing
+        self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+        self.config.algorithm.advantage_fn_args = {}
+        # self.config.buffer.batch_size = 96  # TODO: used for real testing
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.check_and_update()
+        self.config.trainer.trainer_config.trainer.total_training_steps = 4
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+        self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+        both(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+        response_metrics = parser.metric_list("response_length")
+        self.assertTrue(len(response_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+        # TODO: used for real testing
+        # rewards = parser.metric_values("critic/rewards/mean")
+        # self.assertTrue(0.4 < rewards[0] < 0.55)
+        # self.assertTrue(0.4 < rewards[1] < 0.55)
+        # self.assertTrue(0.6 < rewards[2] < 0.7)
+        # self.assertTrue(0.6 < rewards[3] < 0.7)
+        ray.shutdown(_exiting_interpreter=True)
+        # check checkpoint
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerGSM8KWithSFT(BaseTrainerCase):
+    def test_trainer(self):
+        """Test GSM8K With SFT."""
+        # test both mode
+        self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+        self.config.algorithm.repeat_times = 4
+        self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+        self.config.algorithm.advantage_fn_args = {}
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.buffer.trainer_input.sft_warmup_steps = 2
+        self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
+            "sft_for_gsm8k"
+        )
+        self.config.check_and_update()
+        self.config.trainer.trainer_config.trainer.total_training_steps = 4
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+        self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+        both(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2)  # SFT
+        self.assertEqual(parser.metric_max_step(actor_metrics[-1]), 4)  # RFT
+        response_metrics = parser.metric_list("response_length")
+        self.assertTrue(len(response_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+        ray.shutdown(_exiting_interpreter=True)
+        # check checkpoint
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerDPO(BaseTrainerCase):
+    def test_trainer(self):
+        """Test DPO."""
+        # test train mode
+        self.config.mode = "train"
+        self.config.algorithm.algorithm_type = AlgorithmType.DPO
+        self.config.algorithm.policy_loss_fn = "dpo"
+        self.config.algorithm.policy_loss_fn_args = {}
+        # self.config.buffer.batch_size = 32
+        self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo")
+        self.config.check_and_update()
+        self.config.trainer.trainer_config.trainer.total_training_steps = 4
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+        self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 5e-7
+        train(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+        ray.shutdown(_exiting_interpreter=True)
+        # check checkpoint
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
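The warmup arithmetic behind `TestTrainerGSM8KWithSFT` is worth spelling out: with `total_training_steps = 4` and `sft_warmup_steps = 2`, global steps 1-2 run SFT and steps 3-4 run RFT (GRPO), which is why the SFT actor metrics stop at step 2 while the RFT metrics reach step 4. As a sketch:

```python
def step_phase(global_step: int, sft_warmup_steps: int) -> str:
    # Steps are 1-based after this patch (incremented at the top of train_*_step).
    return "sft" if global_step <= sft_warmup_steps else "rft"

assert [step_phase(s, 2) for s in (1, 2, 3, 4)] == ["sft", "sft", "rft", "rft"]
```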
parser.metric_list("response_length") + self.assertTrue(len(response_metrics) > 0) + self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) + ray.shutdown(_exiting_interpreter=True) + # check checkpoint + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) + + +class TestTrainerDPO(BaseTrainerCase): + def test_trainer(self): + """Test DPO.""" + # test both mode + self.config.mode = "train" + self.config.algorithm.algorithm_type = AlgorithmType.DPO + self.config.algorithm.policy_loss_fn = "dpo" + self.config.algorithm.policy_loss_fn_args = {} + # self.config.buffer.batch_size = 32 + self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo") + self.config.check_and_update() + self.config.trainer.trainer_config.trainer.total_training_steps = 4 + self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 + self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 5e-7 + train(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) + ray.shutdown(_exiting_interpreter=True) + # check checkpoint + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) diff --git a/trinity/algorithm/policy_loss_fn/dpo_loss.py b/trinity/algorithm/policy_loss_fn/dpo_loss.py index 3a9ea92f5c..7dfbb7141d 100644 --- a/trinity/algorithm/policy_loss_fn/dpo_loss.py +++ b/trinity/algorithm/policy_loss_fn/dpo_loss.py @@ -1,6 +1,6 @@ """DPO loss function.""" -from typing import Any, Dict, Tuple +from typing import Dict, List, Tuple import torch import torch.nn.functional as F @@ -19,13 +19,11 @@ def __init__( self.beta = beta self.label_smoothing = label_smoothing - def __call__( + def __call__( # type: ignore self, logprob: torch.Tensor, - old_logprob: torch.Tensor, + ref_logprob: torch.Tensor, action_mask: torch.Tensor, - advantages: torch.Tensor, - experiences: Any, **kwargs, ) -> Tuple[torch.Tensor, Dict]: chosen_logprob = logprob[::2] @@ -35,8 +33,8 @@ def __call__( chosen_logprob_sum = masked_sum(chosen_logprob, chosen_mask) rejected_logprob_sum = masked_sum(rejected_logprob, rejected_mask) - chosen_ref_logprob = old_logprob[::2] - rejected_ref_logprob = old_logprob[1::2] + chosen_ref_logprob = ref_logprob[::2] + rejected_ref_logprob = ref_logprob[1::2] chosen_ref_logprob_sum = masked_sum(chosen_ref_logprob, chosen_mask) rejected_ref_logprob_sum = masked_sum(rejected_ref_logprob, rejected_mask) @@ -65,3 +63,10 @@ def default_args(cls) -> Dict: "beta": 0.1, "label_smoothing": 0.0, } + + @property + def select_keys(self) -> List[str]: + return [ + "ref_logprob", + "action_mask", + ] diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py index dd521f9ee0..e9457c55d1 100644 --- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py @@ -3,7 +3,7 @@ Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py """ -from typing import Any, Dict, Tuple +from typing import Dict, List, Tuple import torch @@ -16,13 +16,12 @@ class OPMDPolicyLossFn(PolicyLossFn): def __init__(self, tau: float = 1.0) -> None: self.tau = tau - def __call__( + def __call__( # type: ignore self, logprob: torch.Tensor, - old_logprob: 
diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
index dd521f9ee0..e9457c55d1 100644
--- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -3,7 +3,7 @@
 Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """
 
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple
 
 import torch
 
@@ -16,13 +16,12 @@ class OPMDPolicyLossFn(PolicyLossFn):
     def __init__(self, tau: float = 1.0) -> None:
         self.tau = tau
 
-    def __call__(
+    def __call__(  # type: ignore
        self,
         logprob: torch.Tensor,
-        old_logprob: torch.Tensor,
+        old_logprob: torch.Tensor,  # NOT USED!
         action_mask: torch.Tensor,
         advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         pg_losses = -advantages * logprob
@@ -33,3 +32,11 @@ def __call__(
     @classmethod
     def default_args(cls) -> Dict:
         return {"tau": 1.0}
+
+    @property
+    def select_keys(self) -> List[str]:
+        return [
+            "old_logprob",
+            "action_mask",
+            "advantages",
+        ]
diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
index eb02c49b46..6c1a29b3e9 100644
--- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
+++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple
 
 import torch
 
@@ -17,10 +17,6 @@ class PolicyLossFn(ABC):
     def __call__(
         self,
         logprob: torch.Tensor,
-        old_logprob: torch.Tensor,
-        action_mask: torch.Tensor,
-        advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         """
@@ -29,7 +25,6 @@ def __call__(
             old_logprob (`torch.Tensor`): The log probability generated by the reference model.
             action_mask (`torch.Tensor`): The action mask.
             advantages (`torch.Tensor`): The advantages.
-            experiences (`DataProto`): The input experiences.
             kwargs (`Dict`): The step-level parameters for calculating the policy loss.
 
         Returns:
@@ -44,3 +39,11 @@ def default_args(cls) -> Dict:
         Returns:
             `Dict`: The default init arguments for the policy loss function.
         """
+
+    @property
+    @abstractmethod
+    def select_keys(self) -> List[str]:
+        """
+        Returns:
+            `List[str]`: The keys to select from input data.
+        """
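With `experiences` gone from the base signature, the contract for a loss implementation is now: declare the tensors you need in `select_keys`, and receive them as keyword arguments in `__call__`. A minimal sketch of a custom loss written against this interface (the REINFORCE loss itself is illustrative and not part of this patch):

```python
from typing import Dict, List, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn


class ReinforceLossFn(PolicyLossFn):
    """Illustrative REINFORCE loss: advantage-weighted negative log-prob."""

    def __call__(  # type: ignore
        self,
        logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        # Average over valid (non-padded) response tokens only.
        token_losses = -advantages * logprob * action_mask
        loss = token_losses.sum() / action_mask.sum().clamp(min=1)
        return loss, {"reinforce_loss": loss.detach().item()}

    @classmethod
    def default_args(cls) -> Dict:
        return {}

    @property
    def select_keys(self) -> List[str]:
        # dp_actor translates these to verl batch keys ("action_mask" ->
        # "response_mask") and selects only those columns in update_policy.
        return ["action_mask", "advantages"]
```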
diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
index 9831f048d6..5c735d4d6a 100644
--- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -3,7 +3,7 @@
 Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
 """
 
-from typing import Any, Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 
@@ -30,13 +30,12 @@ def __init__(
         assert self.clip_range_low is not None, "clip_range_low must be specified."
         assert self.clip_range_high is not None, "clip_range_high must be specified."
 
-    def __call__(
+    def __call__(  # type: ignore
         self,
         logprob: torch.Tensor,
         old_logprob: torch.Tensor,
         action_mask: torch.Tensor,
         advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         negative_approx_kl = logprob - old_logprob
@@ -62,3 +61,11 @@ def default_args(cls) -> Dict:
         return {
             "clip_range": 0.2,
         }
+
+    @property
+    def select_keys(self) -> List[str]:
+        return [
+            "old_logprob",
+            "action_mask",
+            "advantages",
+        ]
diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py
index c04f775fa3..dd1c75a4a2 100644
--- a/trinity/algorithm/policy_loss_fn/sft_loss.py
+++ b/trinity/algorithm/policy_loss_fn/sft_loss.py
@@ -1,6 +1,6 @@
 """SFT loss function."""
 
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple
 
 import torch
 
@@ -13,13 +13,10 @@ class SFTLossFn(PolicyLossFn):
     def __init__(self, use_token_level_loss: bool = True) -> None:
         self.use_token_level_loss = use_token_level_loss
 
-    def __call__(
+    def __call__(  # type: ignore
         self,
         logprob: torch.Tensor,
-        old_logprob: torch.Tensor,
         action_mask: torch.Tensor,
-        advantages: torch.Tensor,
-        experiences: Any,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         if self.use_token_level_loss:
@@ -33,3 +30,7 @@ def default_args(cls):
         return {
             "use_token_level_loss": True,
         }
+
+    @property
+    def select_keys(self) -> List[str]:
+        return ["action_mask"]
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 794202bab0..5d294abdfd 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -182,6 +182,9 @@ class AlgorithmConfig:
     # If not set, use AdvantageFn.default_args()
     advantage_fn_args: Optional[dict] = None
 
+    # used for SFT
+    use_token_level_loss: bool = True
+
 
 @dataclass
 class ClusterConfig:
@@ -452,7 +455,7 @@ def _check_buffer(self) -> None:  # noqa: C901
             and self.buffer.trainer_input.sft_warmup_dataset is None
         ):
             raise ValueError(
-                "buffer.trainer_input.sft_warmup_dataset is required when buffer.trainer_input.sft_warmup_steps > 0"
+                "`buffer.trainer_input.sft_warmup_dataset` is required when `buffer.trainer_input.sft_warmup_steps` > 0"
             )
         if self.buffer.trainer_input.sft_warmup_dataset is not None:
            self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT
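The new `algorithm.use_token_level_loss` flag is forwarded to `SFTLossFn` during warmup (see `trainer.py` below). The distinction it controls, as a toy standalone function (the real `SFTLossFn` internals are not shown in this diff, so this is only an assumed reading of the flag):

```python
import torch


def sft_nll(logprob: torch.Tensor, action_mask: torch.Tensor, token_level: bool) -> torch.Tensor:
    # Token-level: average the NLL over every valid token in the batch.
    # Sequence-level: average within each sequence first, then across sequences.
    nll = -logprob * action_mask
    if token_level:
        return nll.sum() / action_mask.sum().clamp(min=1)
    per_seq = nll.sum(dim=-1) / action_mask.sum(dim=-1).clamp(min=1)
    return per_seq.mean()
```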
""" - self.engine.set_algorithm(self.config.algorithm) + if algo_type.is_sft(): + algorithm_config = AlgorithmConfig( + algorithm_type=AlgorithmType.SFT, + policy_loss_fn="sft", + policy_loss_fn_args={ + "use_token_level_loss": self.config.algorithm.use_token_level_loss + }, + ) + self.engine.set_algorithm(algorithm_config) + else: + self.engine.set_algorithm(self.config.algorithm) if algo_type.is_rft() and self.config.buffer.trainer_input.read_experience_strategy: strategy = self.config.buffer.trainer_input.read_experience_strategy else: diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index b598bb6dad..97cd186c36 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -279,16 +279,25 @@ def update_policy(self, data: DataProto): # noqa: C901 "temperature" ] # temperature must be in the data.meta_info to avoid slient error select_keys = [ - "responses", "input_ids", - "attention_mask", "position_ids", - "old_log_probs", - "advantages", + "attention_mask", + "responses", "response_mask", ] + select_keys_verl2trinity = { + "old_log_probs": "old_logprob", + "ref_log_prob": "ref_logprob", + "response_mask": "action_mask", + "advantages": "advantages", + } + select_keys_trinity2verl = {value: key for key, value in select_keys_verl2trinity.items()} + for trinity_key in self.policy_loss_fn.select_keys: + verl_key = select_keys_trinity2verl[trinity_key] + select_keys.append(verl_key) if self.config.use_kl_loss: select_keys.append("ref_log_prob") + select_keys = list(set(select_keys)) batch = data.select(batch_keys=select_keys).batch has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() @@ -351,11 +360,8 @@ def update_policy(self, data: DataProto): # noqa: C901 responses = data["responses"] response_length = responses.size(1) attention_mask = data["attention_mask"] - # response_mask = attention_mask[:, -response_length:] response_mask = data["response_mask"] assert response_mask.shape == attention_mask[:, -response_length:].shape - old_log_prob = data["old_log_probs"] - advantages = data["advantages"] entropy_coeff = self.config.entropy_coeff # all return: (bsz, response_length) @@ -363,12 +369,14 @@ def update_policy(self, data: DataProto): # noqa: C901 micro_batch=data, temperature=temperature ) + kwargs = { + select_keys_verl2trinity[verl_key]: value + for verl_key, value in data.items() + if verl_key in select_keys_verl2trinity + } pg_loss, metric = self.policy_loss_fn( # type: ignore logprob=log_prob, - old_logprob=old_log_prob, - action_mask=response_mask, - advantages=advantages, - experiences=data, + **kwargs, ) # compute entropy loss from entropy diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index b6397adde7..ca02b6c288 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -4,6 +4,7 @@ Modified from verl/trainer/ppo/ray_trainer.py """ import os +import sys from pprint import pprint from typing import Tuple @@ -169,29 +170,14 @@ def prepare(self): return # we start from step 1 - self.global_steps += 1 def _create_dataloader(self): self.train_dataloader = _InternalDataLoader(self.config) # TODO: compute total training steps - # if self.algorithm_type.is_dpo(): - # train_batch_size = self.config.buffer.read_batch_size - # total_epochs = self.config.trainer.total_epochs - # from math import ceil - - # self.total_training_steps = ceil( - # self.train_dataloader.size() // train_batch_size * total_epochs - # ) - # if not 
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index b6397adde7..ca02b6c288 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -4,6 +4,7 @@
 Modified from verl/trainer/ppo/ray_trainer.py
 """
 import os
+import sys
 from pprint import pprint
 from typing import Tuple
 
@@ -169,29 +170,14 @@ def prepare(self):
             return
 
         # we start from step 1
-        self.global_steps += 1
 
     def _create_dataloader(self):
         self.train_dataloader = _InternalDataLoader(self.config)
 
         # TODO: compute total training steps
-        # if self.algorithm_type.is_dpo():
-        #     train_batch_size = self.config.buffer.read_batch_size
-        #     total_epochs = self.config.trainer.total_epochs
-        #     from math import ceil
-
-        #     self.total_training_steps = ceil(
-        #         self.train_dataloader.size() // train_batch_size * total_epochs
-        #     )
-        #     if not self.config.actor_rollout_ref.actor.optim.total_training_steps > 0:
-        #         self.config.actor_rollout_ref.actor.optim.total_training_steps = (
-        #             self.total_training_steps
-        #         )
-        #     if not self.config.critic.optim.total_training_steps > 0:
-        #         self.config.critic.optim.total_training_steps = self.total_training_steps
-        # else:
-        self.total_training_steps = float("inf")
+        self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize
 
     def train_dpo_step(self, experiences: Experiences) -> Tuple[bool, int]:
+        self.global_steps += 1
         metrics = {}
         timing_raw = {}
@@ -251,12 +237,23 @@ def train_dpo_step(self, experiences: Experiences) -> Tuple[bool, int]:
             with _timer("save_checkpoint", timing_raw):
                 self._save_checkpoint()
 
-        self.global_steps += 1
-        return True, self.global_steps - 1
+        if self.global_steps >= self.total_training_steps:
+            if (
+                self.config.trainer.save_freq > 0
+                and self.global_steps % self.config.trainer.save_freq != 0
+            ):
+                with _timer("save_checkpoint", timing_raw):
+                    self._save_checkpoint()
+            # stop training
+            return False, self.global_steps
+        else:
+            # continue
+            return True, self.global_steps
 
     def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]:
         if self.sft_warmup_step_num >= self.config.trainer.sft_warmup_steps:
-            return False, self.global_steps - 1
+            return False, self.global_steps
+        self.global_steps += 1
         metrics = {}
         timing_raw = {}
@@ -308,18 +305,19 @@ def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]:
 
         # TODO: log as sft metrics
         self.logger.log(data=metrics, step=self.global_steps)
         self.sft_warmup_step_num += 1
-        self.global_steps += 1
+        train_status = True
         if self.sft_warmup_step_num == self.config.trainer.sft_warmup_steps:
             self.logger.log(
                 data={"sft_warmup_steps": self.sft_warmup_step_num},
-                step=self.global_steps - 1,
+                step=self.global_steps,
             )
             with _timer("save_checkpoint", timing_raw):
                 self._save_checkpoint()
-            return False, self.global_steps - 1
-        return True, self.global_steps - 1
+            train_status = False
+        return train_status, self.global_steps
 
     def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
+        self.global_steps += 1
         metrics = {}
         timing_raw = {}
@@ -426,20 +424,18 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
         # TODO: make a canonical logger that supports various backend
         self.logger.log(data=metrics, step=self.global_steps)
 
-        self.global_steps += 1
-
         if self.global_steps >= self.total_training_steps:
             if (
                 self.config.trainer.save_freq > 0
-                and (self.global_steps - 1) % self.config.trainer.save_freq != 0
+                and self.global_steps % self.config.trainer.save_freq != 0
             ):
                 with _timer("save_checkpoint", timing_raw):
                     self._save_checkpoint()
             # stop training
-            return False, self.global_steps - 1
+            return False, self.global_steps
         else:
             # continue
-            return True, self.global_steps - 1
+            return True, self.global_steps
 
     def _log_single_experience(
         self, experiences: Experiences, idx: int, skip_special_tokens: bool
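Finally, `_create_dataloader` now takes its step budget from config rather than hard-coding infinity; the `x or sys.maxsize` idiom treats both `0` and `None` as 'not set'. A quick check of that behavior:

```python
import sys


def resolve_total_steps(configured) -> int:
    # 0 or None mean "not configured" -> effectively unlimited.
    return configured or sys.maxsize


assert resolve_total_steps(4) == 4
assert resolve_total_steps(0) == sys.maxsize
assert resolve_total_steps(None) == sys.maxsize
```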