diff --git a/pyproject.toml b/pyproject.toml
index dcf86f8349..022c9a8ffe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
dependencies = [
"verl==0.3.0.post1",
"ray[default]>=2.45.0",
- "vllm>=0.8.5",
+ "vllm==0.8.5.post1",
"tensordict==0.6.2",
"wandb",
"omegaconf",
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index c83d938c66..3a767df243 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -18,8 +18,8 @@ model:
max_prompt_tokens: 2048
max_response_tokens: 2048
cluster: # 2 for explorer, 2 for trainer
- node_num: 1
- gpu_per_node: 4
+ node_num: 2
+ gpu_per_node: 2
buffer:
total_epochs: 1
batch_size: 4
diff --git a/tests/template/data/sft_for_gsm8k/sft.jsonl b/tests/template/data/sft_for_gsm8k/sft.jsonl
new file mode 100644
index 0000000000..a8d6972103
--- /dev/null
+++ b/tests/template/data/sft_for_gsm8k/sft.jsonl
@@ -0,0 +1,32 @@
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
+{"prompt": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?", "response": " \\nKen got $1750. Tony got twice as much as Ken, so Tony got 2 x $1750 = $3500. The total amount shared between Ken and Tony is $1750 + $3500 = $5250.\\n\\n $5250 "}
+{"prompt": "Carrie works for $8 an hour and 35 hours a week at her job. It’s been a month since she started working there. She has saved up all of her money because she wants to buy a bike for $400. How much money will she have left over after she buys her bike?", "response": " \\nCarrie\\'s hourly wage is $8 and she works 35 hours a week. \\nSo her weekly earnings are 35 hours * $8/hour = $280. \\nIt has been a month since she started working, so she has worked 4 weeks. \\nSo her total earnings for the month are 4 weeks * $280/week = $1120. \\nShe wants to buy a bike for $400. \\nSo she will have $1120 - $400 = $720 left over. \\n\\n $720 "}
+{"prompt": "If Mr. Jones has 6 shirts for every pair of pants, and he has 40 pants, what's the total number of pieces of clothes he owns if all other factors remain the same?", "response": " Mr. Jones has 6 shirts for every pair of pants. Since he has 40 pairs of pants, he has 6 * 40 = 240 shirts. So in total, he owns 240 shirts + 40 pants = 280 pieces of clothes. \\n 280 "}
+{"prompt": "Erica is telling her grandson about the animals she saw during her recent safari in Kenya's Maasai Mara. On Saturday, she saw 3 lions and 2 elephants. She spotted 2 buffaloes and 5 leopards on Sunday, and 5 rhinos and 3 warthogs on Monday. What is the total number of animals that Erica saw?", "response": " First, let\\'s add the animals Erica saw on Saturday. She saw 3 lions and 2 elephants, which gives us a total of 3 + 2 = 5 animals. On Sunday, she saw 2 buffaloes and 5 leopards, which gives us a total of 2 + 5 = 7 animals. On Monday, she saw 5 rhinos and 3 warthogs, which gives us a total of 5 + 3 = 8 animals. To find the total number of animals Erica saw, we add up the totals from each day: 5 (Saturday) + 7 (Sunday) + 8 (Monday) = 20 animals. Therefore, Erica saw a total of 20 animals. \\n 20 "}
diff --git a/tests/tools.py b/tests/tools.py
index 2e34438d66..3111839a37 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -13,6 +13,7 @@
StorageConfig,
load_config,
)
+from trinity.common.constants import PromptType
def get_template_config() -> Config:
@@ -59,6 +60,47 @@ def get_unittest_dataset_config(
default_workflow_type="math_workflow",
default_reward_fn_type="countdown_reward",
)
+ elif dataset_name == "gsm8k":
+ return StorageConfig(
+ name=dataset_name,
+ path="openai/gsm8k",
+ split=split,
+ subset_name="main",
+ format=FormatConfig(
+ prompt_key="question",
+ response_key="answer",
+ ),
+ rollout_args=GenerationConfig(
+ n=1,
+ temperature=1.0,
+ logprobs=0,
+ ),
+ default_workflow_type="math_workflow",
+ default_reward_fn_type="math_reward",
+ )
+ elif dataset_name == "sft_for_gsm8k":
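+        # local JSONL fixture with plain-text prompt/response pairs,
+        # stored under tests/template/data/sft_for_gsm8k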
+ return StorageConfig(
+ name=dataset_name,
+ path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"),
+ split="train",
+ format=FormatConfig(
+ prompt_type=PromptType.PLAINTEXT,
+ prompt_key="prompt",
+ response_key="response",
+ ),
+ )
+ elif dataset_name == "dpo":
+ return StorageConfig(
+ name=dataset_name,
+ path="HumanLLMs/Human-Like-DPO-Dataset",
+ split="train",
+ format=FormatConfig(
+ prompt_type=PromptType.PLAINTEXT,
+ prompt_key="prompt",
+ chosen_key="chosen",
+ rejected_key="rejected",
+ ),
+ )
else:
raise ValueError(f"Unknown dataset name: {dataset_name}")
@@ -104,6 +146,11 @@ def metric_steps(self, metric_name: str) -> List[int]:
raise ValueError(f"Metric '{metric_name}' does not exist.")
return list(self._metrics[metric_name].keys())
+ def metric_values(self, metric_name: str) -> List:
+ if not self.metric_exist(metric_name):
+ raise ValueError(f"Metric '{metric_name}' does not exist.")
+ return list(self._metrics[metric_name].values())
+
def metric_list(self, metric_prefix: str) -> List[str]:
return [name for name in self._metrics if name.startswith(metric_prefix)]
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index ac73e46c8d..55f63ae856 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -14,8 +14,8 @@
get_template_config,
get_unittest_dataset_config,
)
-from trinity.cli.launcher import bench, both
-from trinity.common.constants import MonitorType, SyncMethod
+from trinity.cli.launcher import bench, both, train
+from trinity.common.constants import AlgorithmType, MonitorType, SyncMethod
class BaseTrainerCase(RayUnittestBase):
@@ -109,3 +109,107 @@ def test_trainer(self):
def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerGSM8K(BaseTrainerCase):
+ def test_trainer(self):
+ """Test GSM8K."""
+ # test both mode
+ self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+ self.config.algorithm.repeat_times = 4
+ # self.config.algorithm.repeat_times = 8 # TODO: used for real testing
+ self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+ self.config.algorithm.advantage_fn_args = {}
+ # self.config.buffer.batch_size = 96 # TODO: used for real testing
+ self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+ self.config.check_and_update()
+ self.config.trainer.trainer_config.trainer.total_training_steps = 4
+ self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+ self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+ both(self.config)
+ parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+ rollout_metrics = parser.metric_list("rollout")
+ self.assertTrue(len(rollout_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
+ actor_metrics = parser.metric_list("actor")
+ self.assertTrue(len(actor_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+ response_metrics = parser.metric_list("response_length")
+ self.assertTrue(len(response_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+ # TODO: used for real testing
+ # rewards = parser.metric_values("critic/rewards/mean")
+ # self.assertTrue(0.4 < rewards[0] < 0.55)
+ # self.assertTrue(0.4 < rewards[1] < 0.55)
+ # self.assertTrue(0.6 < rewards[2] < 0.7)
+ # self.assertTrue(0.6 < rewards[3] < 0.7)
+ ray.shutdown(_exiting_interpreter=True)
+        # TODO: check saved checkpoint files
+
+ def tearDown(self):
+ # remove dir only when the test passed
+ shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerGSM8KWithSFT(BaseTrainerCase):
+ def test_trainer(self):
+ """Test GSM8K With SFT."""
+ # test both mode
+ self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+ self.config.algorithm.repeat_times = 4
+ self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+ self.config.algorithm.advantage_fn_args = {}
+ self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+ self.config.buffer.trainer_input.sft_warmup_steps = 2
+ self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
+ "sft_for_gsm8k"
+ )
+ self.config.check_and_update()
+ self.config.trainer.trainer_config.trainer.total_training_steps = 4
+ self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+ self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
+ both(self.config)
+ parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+ rollout_metrics = parser.metric_list("rollout")
+ self.assertTrue(len(rollout_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
+ actor_metrics = parser.metric_list("actor")
+ self.assertTrue(len(actor_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) # SFT
+ self.assertEqual(parser.metric_max_step(actor_metrics[-1]), 4) # RFT
+ response_metrics = parser.metric_list("response_length")
+ self.assertTrue(len(response_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+ ray.shutdown(_exiting_interpreter=True)
+        # TODO: check saved checkpoint files
+
+ def tearDown(self):
+ # remove dir only when the test passed
+ shutil.rmtree(self.config.checkpoint_job_dir)
+
+
+class TestTrainerDPO(BaseTrainerCase):
+ def test_trainer(self):
+ """Test DPO."""
+        # test train mode
+ self.config.mode = "train"
+ self.config.algorithm.algorithm_type = AlgorithmType.DPO
+ self.config.algorithm.policy_loss_fn = "dpo"
+ self.config.algorithm.policy_loss_fn_args = {}
+ # self.config.buffer.batch_size = 32
+ self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo")
+ self.config.check_and_update()
+ self.config.trainer.trainer_config.trainer.total_training_steps = 4
+ self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+ self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 5e-7
+ train(self.config)
+ parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+ actor_metrics = parser.metric_list("actor")
+ self.assertTrue(len(actor_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+ ray.shutdown(_exiting_interpreter=True)
+        # TODO: check saved checkpoint files
+
+ def tearDown(self):
+ # remove dir only when the test passed
+ shutil.rmtree(self.config.checkpoint_job_dir)
diff --git a/trinity/algorithm/policy_loss_fn/dpo_loss.py b/trinity/algorithm/policy_loss_fn/dpo_loss.py
index 3a9ea92f5c..7dfbb7141d 100644
--- a/trinity/algorithm/policy_loss_fn/dpo_loss.py
+++ b/trinity/algorithm/policy_loss_fn/dpo_loss.py
@@ -1,6 +1,6 @@
"""DPO loss function."""
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple
import torch
import torch.nn.functional as F
@@ -19,13 +19,11 @@ def __init__(
self.beta = beta
self.label_smoothing = label_smoothing
- def __call__(
+ def __call__( # type: ignore
self,
logprob: torch.Tensor,
- old_logprob: torch.Tensor,
+ ref_logprob: torch.Tensor,
action_mask: torch.Tensor,
- advantages: torch.Tensor,
- experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
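+        # Experiences arrive interleaved as [chosen, rejected, chosen, ...],
+        # so even rows hold chosen responses and odd rows rejected ones.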
chosen_logprob = logprob[::2]
@@ -35,8 +33,8 @@ def __call__(
chosen_logprob_sum = masked_sum(chosen_logprob, chosen_mask)
rejected_logprob_sum = masked_sum(rejected_logprob, rejected_mask)
- chosen_ref_logprob = old_logprob[::2]
- rejected_ref_logprob = old_logprob[1::2]
+ chosen_ref_logprob = ref_logprob[::2]
+ rejected_ref_logprob = ref_logprob[1::2]
chosen_ref_logprob_sum = masked_sum(chosen_ref_logprob, chosen_mask)
rejected_ref_logprob_sum = masked_sum(rejected_ref_logprob, rejected_mask)
@@ -65,3 +63,10 @@ def default_args(cls) -> Dict:
"beta": 0.1,
"label_smoothing": 0.0,
}
+
+ @property
+ def select_keys(self) -> List[str]:
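+        # DPO compares the policy against a frozen reference model, so it
+        # needs ref_logprob rather than old_logprob or advantages.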
+ return [
+ "ref_logprob",
+ "action_mask",
+ ]
diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
index dd521f9ee0..e9457c55d1 100644
--- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -3,7 +3,7 @@
Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple
import torch
@@ -16,13 +16,12 @@ class OPMDPolicyLossFn(PolicyLossFn):
def __init__(self, tau: float = 1.0) -> None:
self.tau = tau
- def __call__(
+ def __call__( # type: ignore
self,
logprob: torch.Tensor,
- old_logprob: torch.Tensor,
+        old_logprob: torch.Tensor,  # accepted but unused by the OPMD loss
action_mask: torch.Tensor,
advantages: torch.Tensor,
- experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
pg_losses = -advantages * logprob
@@ -33,3 +32,11 @@ def __call__(
@classmethod
def default_args(cls) -> Dict:
return {"tau": 1.0}
+
+ @property
+ def select_keys(self) -> List[str]:
+ return [
+ "old_logprob",
+ "action_mask",
+ "advantages",
+ ]
diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
index eb02c49b46..6c1a29b3e9 100644
--- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
+++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple
import torch
@@ -17,10 +17,6 @@ class PolicyLossFn(ABC):
def __call__(
self,
logprob: torch.Tensor,
- old_logprob: torch.Tensor,
- action_mask: torch.Tensor,
- advantages: torch.Tensor,
- experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
"""
@@ -29,7 +25,6 @@ def __call__(
             old_logprob (`torch.Tensor`): The log probability generated by the old policy.
action_mask (`torch.Tensor`): The action mask.
advantages (`torch.Tensor`): The advantages.
- experiences (`DataProto`): The input experiences.
             kwargs (`Dict`): Tensors selected via `select_keys` and other step-level parameters for calculating the policy loss.
Returns:
@@ -44,3 +39,11 @@ def default_args(cls) -> Dict:
Returns:
`Dict`: The default init arguments for the policy loss function.
"""
+
+ @property
+ @abstractmethod
+ def select_keys(self) -> List[str]:
+ """
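+        The batch keys this loss function reads. `logprob` is computed in
+        the forward pass and passed explicitly, so it is not listed here.
+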
+ Returns:
+ `List[str]`: The keys to select from input data.
+ """
diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
index 9831f048d6..5c735d4d6a 100644
--- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -3,7 +3,7 @@
Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""
-from typing import Any, Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
import torch
@@ -30,13 +30,12 @@ def __init__(
assert self.clip_range_low is not None, "clip_range_low must be specified."
assert self.clip_range_high is not None, "clip_range_high must be specified."
- def __call__(
+ def __call__( # type: ignore
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
- experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
negative_approx_kl = logprob - old_logprob
@@ -62,3 +61,11 @@ def default_args(cls) -> Dict:
return {
"clip_range": 0.2,
}
+
+ @property
+ def select_keys(self) -> List[str]:
+ return [
+ "old_logprob",
+ "action_mask",
+ "advantages",
+ ]
diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py
index c04f775fa3..dd1c75a4a2 100644
--- a/trinity/algorithm/policy_loss_fn/sft_loss.py
+++ b/trinity/algorithm/policy_loss_fn/sft_loss.py
@@ -1,6 +1,6 @@
"""SFT loss function."""
-from typing import Any, Dict, Tuple
+from typing import Dict, List, Tuple
import torch
@@ -13,13 +13,10 @@ class SFTLossFn(PolicyLossFn):
def __init__(self, use_token_level_loss: bool = True) -> None:
self.use_token_level_loss = use_token_level_loss
- def __call__(
+ def __call__( # type: ignore
self,
logprob: torch.Tensor,
- old_logprob: torch.Tensor,
action_mask: torch.Tensor,
- advantages: torch.Tensor,
- experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
if self.use_token_level_loss:
@@ -33,3 +30,7 @@ def default_args(cls):
return {
"use_token_level_loss": True,
}
+
+ @property
+ def select_keys(self) -> List[str]:
+ return ["action_mask"]
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 794202bab0..5d294abdfd 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -182,6 +182,9 @@ class AlgorithmConfig:
# If not set, use AdvantageFn.default_args()
advantage_fn_args: Optional[dict] = None
+    # for SFT: average the loss over tokens (True) or over sequences (False)
+ use_token_level_loss: bool = True
+
@dataclass
class ClusterConfig:
@@ -452,7 +455,7 @@ def _check_buffer(self) -> None: # noqa: C901
and self.buffer.trainer_input.sft_warmup_dataset is None
):
raise ValueError(
- "buffer.trainer_input.sft_warmup_dataset is required when buffer.trainer_input.sft_warmup_steps > 0"
+ "`buffer.trainer_input.sft_warmup_dataset` is required when `buffer.trainer_input.sft_warmup_steps` > 0"
)
if self.buffer.trainer_input.sft_warmup_dataset is not None:
self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 876ca2835f..7208f83fb4 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -73,7 +73,17 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
Returns:
             Tuple[bool, int]: Whether to continue training, and the current train step.
"""
- self.engine.set_algorithm(self.config.algorithm)
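+        # During SFT warmup steps, replace the configured algorithm with an
+        # SFT loss; all other steps use the user-configured algorithm.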
+ if algo_type.is_sft():
+ algorithm_config = AlgorithmConfig(
+ algorithm_type=AlgorithmType.SFT,
+ policy_loss_fn="sft",
+ policy_loss_fn_args={
+ "use_token_level_loss": self.config.algorithm.use_token_level_loss
+ },
+ )
+ self.engine.set_algorithm(algorithm_config)
+ else:
+ self.engine.set_algorithm(self.config.algorithm)
if algo_type.is_rft() and self.config.buffer.trainer_input.read_experience_strategy:
strategy = self.config.buffer.trainer_input.read_experience_strategy
else:
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index b598bb6dad..97cd186c36 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -279,16 +279,25 @@ def update_policy(self, data: DataProto): # noqa: C901
"temperature"
         ] # temperature must be in the data.meta_info to avoid silent errors
select_keys = [
- "responses",
"input_ids",
- "attention_mask",
"position_ids",
- "old_log_probs",
- "advantages",
+ "attention_mask",
+ "responses",
"response_mask",
]
+ select_keys_verl2trinity = {
+ "old_log_probs": "old_logprob",
+ "ref_log_prob": "ref_logprob",
+ "response_mask": "action_mask",
+ "advantages": "advantages",
+ }
+ select_keys_trinity2verl = {value: key for key, value in select_keys_verl2trinity.items()}
+ for trinity_key in self.policy_loss_fn.select_keys:
+ verl_key = select_keys_trinity2verl[trinity_key]
+ select_keys.append(verl_key)
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
+ select_keys = list(set(select_keys))
batch = data.select(batch_keys=select_keys).batch
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
@@ -351,11 +360,8 @@ def update_policy(self, data: DataProto): # noqa: C901
responses = data["responses"]
response_length = responses.size(1)
attention_mask = data["attention_mask"]
- # response_mask = attention_mask[:, -response_length:]
response_mask = data["response_mask"]
assert response_mask.shape == attention_mask[:, -response_length:].shape
- old_log_prob = data["old_log_probs"]
- advantages = data["advantages"]
entropy_coeff = self.config.entropy_coeff
# all return: (bsz, response_length)
@@ -363,12 +369,14 @@ def update_policy(self, data: DataProto): # noqa: C901
micro_batch=data, temperature=temperature
)
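+                    # translate the selected verl tensors back to their
+                    # trinity argument names before calling the loss function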
+ kwargs = {
+ select_keys_verl2trinity[verl_key]: value
+ for verl_key, value in data.items()
+ if verl_key in select_keys_verl2trinity
+ }
pg_loss, metric = self.policy_loss_fn( # type: ignore
logprob=log_prob,
- old_logprob=old_log_prob,
- action_mask=response_mask,
- advantages=advantages,
- experiences=data,
+ **kwargs,
)
# compute entropy loss from entropy
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index b6397adde7..ca02b6c288 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -4,6 +4,7 @@
Modified from verl/trainer/ppo/ray_trainer.py
"""
import os
+import sys
from pprint import pprint
from typing import Tuple
@@ -169,29 +170,14 @@ def prepare(self):
return
# we start from step 1
- self.global_steps += 1
def _create_dataloader(self):
self.train_dataloader = _InternalDataLoader(self.config)
# TODO: compute total training steps
- # if self.algorithm_type.is_dpo():
- # train_batch_size = self.config.buffer.read_batch_size
- # total_epochs = self.config.trainer.total_epochs
- # from math import ceil
-
- # self.total_training_steps = ceil(
- # self.train_dataloader.size() // train_batch_size * total_epochs
- # )
- # if not self.config.actor_rollout_ref.actor.optim.total_training_steps > 0:
- # self.config.actor_rollout_ref.actor.optim.total_training_steps = (
- # self.total_training_steps
- # )
- # if not self.config.critic.optim.total_training_steps > 0:
- # self.config.critic.optim.total_training_steps = self.total_training_steps
- # else:
- self.total_training_steps = float("inf")
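+        # treat an unset total_training_steps as "train indefinitely"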
+ self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize
def train_dpo_step(self, experiences: Experiences) -> Tuple[bool, int]:
+ self.global_steps += 1
metrics = {}
timing_raw = {}
@@ -251,12 +237,23 @@ def train_dpo_step(self, experiences: Experiences) -> Tuple[bool, int]:
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
- self.global_steps += 1
- return True, self.global_steps - 1
+ if self.global_steps >= self.total_training_steps:
+ if (
+ self.config.trainer.save_freq > 0
+ and self.global_steps % self.config.trainer.save_freq != 0
+ ):
+ with _timer("save_checkpoint", timing_raw):
+ self._save_checkpoint()
+ # stop training
+ return False, self.global_steps
+ else:
+ # continue
+ return True, self.global_steps
def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]:
if self.sft_warmup_step_num >= self.config.trainer.sft_warmup_steps:
- return False, self.global_steps - 1
+ return False, self.global_steps
+ self.global_steps += 1
metrics = {}
timing_raw = {}
@@ -308,18 +305,19 @@ def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]:
# TODO: log as sft metrics
self.logger.log(data=metrics, step=self.global_steps)
self.sft_warmup_step_num += 1
- self.global_steps += 1
+ train_status = True
if self.sft_warmup_step_num == self.config.trainer.sft_warmup_steps:
self.logger.log(
data={"sft_warmup_steps": self.sft_warmup_step_num},
- step=self.global_steps - 1,
+ step=self.global_steps,
)
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
- return False, self.global_steps - 1
- return True, self.global_steps - 1
+ train_status = False
+ return train_status, self.global_steps
def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
+ self.global_steps += 1
metrics = {}
timing_raw = {}
@@ -426,20 +424,18 @@ def train_rft_step(self, experiences: Experiences) -> Tuple[bool, int]:
# TODO: make a canonical logger that supports various backend
self.logger.log(data=metrics, step=self.global_steps)
- self.global_steps += 1
-
if self.global_steps >= self.total_training_steps:
if (
self.config.trainer.save_freq > 0
- and (self.global_steps - 1) % self.config.trainer.save_freq != 0
+ and self.global_steps % self.config.trainer.save_freq != 0
):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# stop training
- return False, self.global_steps - 1
+ return False, self.global_steps
else:
# continue
- return True, self.global_steps - 1
+ return True, self.global_steps
def _log_single_experience(
self, experiences: Experiences, idx: int, skip_special_tokens: bool