2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
dependencies = [
"verl==0.3.0.post1",
"ray[default]>=2.45.0",
"vllm>=0.8.5",
"vllm==0.8.5.post1",
"tensordict==0.6.2",
"wandb",
"omegaconf",
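The pin above replaces the open-ended `vllm>=0.8.5` requirement with an exact version. A minimal sketch, not part of the PR, of how a test environment could verify that pin before running the GPU tests:

```python
# Minimal sketch (not part of the repository): confirm the environment matches
# the exact pin declared in pyproject.toml above.
from importlib.metadata import version

installed = version("vllm")
assert installed == "0.8.5.post1", f"expected vllm==0.8.5.post1, found {installed}"
```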
4 changes: 2 additions & 2 deletions tests/template/config.yaml
@@ -18,8 +18,8 @@ model:
max_prompt_tokens: 2048
max_response_tokens: 2048
cluster: # 2 for explorer, 2 for trainer
node_num: 1
gpu_per_node: 4
node_num: 2
gpu_per_node: 2
buffer:
total_epochs: 1
batch_size: 4
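The cluster change keeps the same 4-GPU total but splits it across two nodes, matching the `# 2 for explorer, 2 for trainer` comment. A rough sketch of that invariant, assuming OmegaConf (already a project dependency) and the field names shown above:

```python
# Rough sketch: 2 nodes x 2 GPUs replaces 1 node x 4 GPUs, so the total GPU
# budget used by the template test config is unchanged.
from omegaconf import OmegaConf

cluster = OmegaConf.create({"node_num": 2, "gpu_per_node": 2})
assert cluster.node_num * cluster.gpu_per_node == 4  # 2 for explorer, 2 for trainer
```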
32 changes: 32 additions & 0 deletions tests/template/data/sft_for_gsm8k/sft.jsonl

Large diffs are not rendered by default.
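The 32 added lines are not rendered here. Based on the FormatConfig registered for "sft_for_gsm8k" in tests/tools.py below (`prompt_key="prompt"`, `response_key="response"`), each record plausibly has the following shape; the values are hypothetical placeholders, not contents of the actual file:

```python
# Hypothetical record shape for sft.jsonl; only the keys are taken from the
# FormatConfig below, the values are placeholders.
import json

record = {
    "prompt": "<a GSM8K-style question>",
    "response": "<a step-by-step solution ending in the final answer>",
}
print(json.dumps(record))
```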

47 changes: 47 additions & 0 deletions tests/tools.py
@@ -13,6 +13,7 @@
StorageConfig,
load_config,
)
from trinity.common.constants import PromptType


def get_template_config() -> Config:
@@ -59,6 +60,47 @@ def get_unittest_dataset_config(
default_workflow_type="math_workflow",
default_reward_fn_type="countdown_reward",
)
elif dataset_name == "gsm8k":
return StorageConfig(
name=dataset_name,
path="openai/gsm8k",
split=split,
subset_name="main",
format=FormatConfig(
prompt_key="question",
response_key="answer",
),
rollout_args=GenerationConfig(
n=1,
temperature=1.0,
logprobs=0,
),
default_workflow_type="math_workflow",
default_reward_fn_type="math_reward",
)
elif dataset_name == "sft_for_gsm8k":
return StorageConfig(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"),
split="train",
format=FormatConfig(
prompt_type=PromptType.PLAINTEXT,
prompt_key="prompt",
response_key="response",
),
)
elif dataset_name == "dpo":
return StorageConfig(
name=dataset_name,
path="HumanLLMs/Human-Like-DPO-Dataset",
split="train",
format=FormatConfig(
prompt_type=PromptType.PLAINTEXT,
prompt_key="prompt",
chosen_key="chosen",
rejected_key="rejected",
),
)
else:
raise ValueError(f"Unknown dataset name: {dataset_name}")
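The new entries are looked up purely by name, mirroring how the trainer tests below call them. A short usage sketch, assuming the helper is importable as `tests.tools` and that StorageConfig / FormatConfig are plain dataclasses with the attribute names used in the constructors above:

```python
# Usage sketch: the trainer tests select these dataset configs by name only.
from tests.tools import get_unittest_dataset_config

gsm8k_cfg = get_unittest_dataset_config("gsm8k")
assert gsm8k_cfg.path == "openai/gsm8k"

dpo_cfg = get_unittest_dataset_config("dpo")
assert dpo_cfg.format.chosen_key == "chosen"
```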

@@ -104,6 +146,11 @@ def metric_steps(self, metric_name: str) -> List[int]:
raise ValueError(f"Metric '{metric_name}' does not exist.")
return list(self._metrics[metric_name].keys())

def metric_values(self, metric_name: str) -> List:
if not self.metric_exist(metric_name):
raise ValueError(f"Metric '{metric_name}' does not exist.")
return list(self._metrics[metric_name].values())

def metric_list(self, metric_prefix: str) -> List[str]:
return [name for name in self._metrics if name.startswith(metric_prefix)]

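The new `metric_values` accessor complements the existing `metric_steps`; the commented-out reward checks in the GSM8K test below are its intended consumer. A small sketch, assuming TensorBoardParser lives in tests/tools.py where the methods above are defined, with a hypothetical log directory:

```python
# Sketch only: "/tmp/example/tensorboard" is a hypothetical log directory, and
# "critic/rewards/mean" mirrors the metric name in the commented-out assertions
# in tests/trainer/trainer_test.py.
from tests.tools import TensorBoardParser

parser = TensorBoardParser("/tmp/example/tensorboard")
steps = parser.metric_steps("critic/rewards/mean")
values = parser.metric_values("critic/rewards/mean")
assert len(steps) == len(values)  # one recorded value per step
```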
108 changes: 106 additions & 2 deletions tests/trainer/trainer_test.py
@@ -14,8 +14,8 @@
get_template_config,
get_unittest_dataset_config,
)
from trinity.cli.launcher import bench, both
from trinity.common.constants import MonitorType, SyncMethod
from trinity.cli.launcher import bench, both, train
from trinity.common.constants import AlgorithmType, MonitorType, SyncMethod


class BaseTrainerCase(RayUnittestBase):
@@ -109,3 +109,107 @@ def test_trainer(self):
def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)


class TestTrainerGSM8K(BaseTrainerCase):
def test_trainer(self):
"""Test GSM8K."""
# test both mode
self.config.algorithm.algorithm_type = AlgorithmType.GRPO
self.config.algorithm.repeat_times = 4
# self.config.algorithm.repeat_times = 8 # TODO: used for real testing
self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
self.config.algorithm.advantage_fn_args = {}
# self.config.buffer.batch_size = 96 # TODO: used for real testing
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.total_training_steps = 4
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
both(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
response_metrics = parser.metric_list("response_length")
self.assertTrue(len(response_metrics) > 0)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
# TODO: used for real testing
# rewards = parser.metric_values("critic/rewards/mean")
# self.assertTrue(0.4 < rewards[0] < 0.55)
# self.assertTrue(0.4 < rewards[1] < 0.55)
# self.assertTrue(0.6 < rewards[2] < 0.7)
# self.assertTrue(0.6 < rewards[3] < 0.7)
ray.shutdown(_exiting_interpreter=True)
# check checkpoint

def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)


class TestTrainerGSM8KWithSFT(BaseTrainerCase):
def test_trainer(self):
"""Test GSM8K With SFT."""
# test both mode
self.config.algorithm.algorithm_type = AlgorithmType.GRPO
self.config.algorithm.repeat_times = 4
self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
self.config.algorithm.advantage_fn_args = {}
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.buffer.trainer_input.sft_warmup_steps = 2
self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
"sft_for_gsm8k"
)
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.total_training_steps = 4
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
both(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) # SFT
self.assertEqual(parser.metric_max_step(actor_metrics[-1]), 4) # RFT
response_metrics = parser.metric_list("response_length")
self.assertTrue(len(response_metrics) > 0)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
ray.shutdown(_exiting_interpreter=True)
# check checkpoint

def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)


class TestTrainerDPO(BaseTrainerCase):
def test_trainer(self):
"""Test DPO."""
# test train mode
self.config.mode = "train"
self.config.algorithm.algorithm_type = AlgorithmType.DPO
self.config.algorithm.policy_loss_fn = "dpo"
self.config.algorithm.policy_loss_fn_args = {}
# self.config.buffer.batch_size = 32
self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo")
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.total_training_steps = 4
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 5e-7
train(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
ray.shutdown(_exiting_interpreter=True)
# check checkpoint

def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)
19 changes: 12 additions & 7 deletions trinity/algorithm/policy_loss_fn/dpo_loss.py
@@ -1,6 +1,6 @@
"""DPO loss function."""

from typing import Any, Dict, Tuple
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
@@ -19,13 +19,11 @@ def __init__(
self.beta = beta
self.label_smoothing = label_smoothing

def __call__(
def __call__( # type: ignore
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
ref_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
chosen_logprob = logprob[::2]
@@ -35,8 +33,8 @@ def __call__(
chosen_logprob_sum = masked_sum(chosen_logprob, chosen_mask)
rejected_logprob_sum = masked_sum(rejected_logprob, rejected_mask)

chosen_ref_logprob = old_logprob[::2]
rejected_ref_logprob = old_logprob[1::2]
chosen_ref_logprob = ref_logprob[::2]
rejected_ref_logprob = ref_logprob[1::2]
chosen_ref_logprob_sum = masked_sum(chosen_ref_logprob, chosen_mask)
rejected_ref_logprob_sum = masked_sum(rejected_ref_logprob, rejected_mask)

@@ -65,3 +63,10 @@ def default_args(cls) -> Dict:
"beta": 0.1,
"label_smoothing": 0.0,
}

@property
def select_keys(self) -> List[str]:
return [
"ref_logprob",
"action_mask",
]
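For reference, the chosen/rejected slicing above (`[::2]` and `[1::2]` over interleaved pairs) feeds the standard DPO objective. A compact sketch of that objective with the same `beta` and `label_smoothing` defaults; this is not the repository implementation, which additionally handles masking via `masked_sum` and metric reporting:

```python
# Standard (conservative) DPO objective on per-sequence log-prob sums.
import torch
import torch.nn.functional as F

def dpo_loss_sketch(chosen_logprob_sum, rejected_logprob_sum,
                    chosen_ref_logprob_sum, rejected_ref_logprob_sum,
                    beta=0.1, label_smoothing=0.0):
    # beta * [(log pi(y_w) - log pi_ref(y_w)) - (log pi(y_l) - log pi_ref(y_l))]
    logits = beta * (
        (chosen_logprob_sum - chosen_ref_logprob_sum)
        - (rejected_logprob_sum - rejected_ref_logprob_sum)
    )
    # label smoothing interpolates between the two preference directions
    loss = (
        -F.logsigmoid(logits) * (1.0 - label_smoothing)
        - F.logsigmoid(-logits) * label_smoothing
    )
    return loss.mean()
```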
15 changes: 11 additions & 4 deletions trinity/algorithm/policy_loss_fn/opmd_policy_loss.py
@@ -3,7 +3,7 @@
Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""

from typing import Any, Dict, Tuple
from typing import Dict, List, Tuple

import torch

@@ -16,13 +16,12 @@ class OPMDPolicyLossFn(PolicyLossFn):
def __init__(self, tau: float = 1.0) -> None:
self.tau = tau

def __call__(
def __call__( # type: ignore
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
old_logprob: torch.Tensor, # NOT USED!
action_mask: torch.Tensor,
advantages: torch.Tensor,
experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
pg_losses = -advantages * logprob
@@ -33,3 +32,11 @@
@classmethod
def default_args(cls) -> Dict:
return {"tau": 1.0}

@property
def select_keys(self) -> List[str]:
return [
"old_logprob",
"action_mask",
"advantages",
]
15 changes: 9 additions & 6 deletions trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple
from typing import Dict, List, Tuple

import torch

@@ -17,10 +17,6 @@ class PolicyLossFn(ABC):
def __call__(
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
"""
@@ -29,7 +25,6 @@ def __call__(
old_logprob (`torch.Tensor`): The log probability generated by the reference model.
action_mask (`torch.Tensor`): The action mask.
advantages (`torch.Tensor`): The advantages.
experiences (`DataProto`): The input experiences.
kwargs (`Dict`): The step-level parameters for calculating the policy loss.

Returns:
@@ -44,3 +39,11 @@ def default_args(cls) -> Dict:
Returns:
`Dict`: The default init arguments for the policy loss function.
"""

@property
@abstractmethod
def select_keys(self) -> List[str]:
"""
Returns:
`List[str]`: The keys to select from input data.
"""
13 changes: 10 additions & 3 deletions trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -3,7 +3,7 @@
Modified from https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py
"""

from typing import Any, Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch

@@ -30,13 +30,12 @@ def __init__(
assert self.clip_range_low is not None, "clip_range_low must be specified."
assert self.clip_range_high is not None, "clip_range_high must be specified."

def __call__(
def __call__( # type: ignore
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
negative_approx_kl = logprob - old_logprob
@@ -62,3 +61,11 @@ def default_args(cls) -> Dict:
return {
"clip_range": 0.2,
}

@property
def select_keys(self) -> List[str]:
return [
"old_logprob",
"action_mask",
"advantages",
]
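As with the DPO sketch above, here is a compact reminder of the clipped surrogate that PPOPolicyLossFn computes, written without the action-mask handling and extra metrics of the real class; `clip_range_low` and `clip_range_high` correspond to the asserted init arguments:

```python
# Clipped PPO surrogate on per-token log-probs; the repository version also
# applies the action mask and reports additional metrics.
import torch

def ppo_clip_sketch(logprob, old_logprob, advantages,
                    clip_range_low=0.2, clip_range_high=0.2):
    ratio = torch.exp(logprob - old_logprob)  # pi_theta / pi_old per token
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range_low, 1.0 + clip_range_high)
    return torch.maximum(unclipped, clipped).mean()
```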
11 changes: 6 additions & 5 deletions trinity/algorithm/policy_loss_fn/sft_loss.py
@@ -1,6 +1,6 @@
"""SFT loss function."""

from typing import Any, Dict, Tuple
from typing import Dict, List, Tuple

import torch

@@ -13,13 +13,10 @@ class SFTLossFn(PolicyLossFn):
def __init__(self, use_token_level_loss: bool = True) -> None:
self.use_token_level_loss = use_token_level_loss

def __call__(
def __call__( # type: ignore
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
if self.use_token_level_loss:
@@ -33,3 +30,7 @@ def default_args(cls):
return {
"use_token_level_loss": True,
}

@property
def select_keys(self) -> List[str]:
return ["action_mask"]