5 changes: 5 additions & 0 deletions tests/template/config.yaml
@@ -8,6 +8,11 @@ algorithm:
policy_loss_fn: ppo
policy_loss_fn_args:
clip_range: 0.2
advantage_fn_type: ppo_adv_fn
advantage_fn_args:
gamma: 1.0
lam: 1.0

model:
model_path: ''
max_prompt_tokens: 2048
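For reference, a minimal sketch of how the template fields above map onto the Python config side, assuming the `AlgorithmConfig` dataclass added in `trinity/common/config.py` later in this diff (the dict literals simply mirror the YAML):

```python
# Hedged illustration only: construct AlgorithmConfig with the same values as the YAML template.
from trinity.common.config import AlgorithmConfig

algo_cfg = AlgorithmConfig(
    policy_loss_fn="ppo",
    policy_loss_fn_args={"clip_range": 0.2},
    advantage_fn_type="ppo_adv_fn",
    advantage_fn_args={"gamma": 1.0, "lam": 1.0},
)
```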
2 changes: 1 addition & 1 deletion trinity/algorithm/__init__.py
@@ -1,4 +1,4 @@
from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn

__all__ = [
20 changes: 20 additions & 0 deletions trinity/algorithm/advantage_fn/__init__.py
@@ -0,0 +1,20 @@
from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn
from trinity.algorithm.advantage_fn.opmd_advantage import OPMDAdvantageFn
from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn
from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import (
REINFORCEPLUSPLUSAdvantageFn,
)
from trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn
from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn

__all__ = [
"ADVANTAGE_FN",
"AdvantageFn",
"PPOAdvantageFn",
"GRPOAdvantageFn",
"REINFORCEPLUSPLUSAdvantageFn",
"REMAXAdvantageFn",
"RLOOAdvantageFn",
"OPMDAdvantageFn",
]
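A short sketch of how the re-exported registry can resolve an advantage function by name; `ADVANTAGE_FN.get`, `register_module`, and `default_args` are all used elsewhere in this diff, so only the chosen key is an example:

```python
from trinity.algorithm.advantage_fn import ADVANTAGE_FN

# Keys registered in this PR (per the register_module calls below):
#   "ppo_adv_fn", "grpo_adv_fn", "reinforceplusplus_adv_fn",
#   "remax_adv_fn", "rloo_adv_fn", "opmd_adv_fn"
advantage_fn_cls = ADVANTAGE_FN.get("grpo_adv_fn")
advantage_fn = advantage_fn_cls(**advantage_fn_cls.default_args())
```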
10 changes: 9 additions & 1 deletion trinity/algorithm/advantage_fn/advantage_fn.py
@@ -16,6 +16,14 @@ def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]:
kwargs (`Dict`): The step-level parameters for calculating advantages.

Returns:
`Any`: The experiences with advantages.
`DataProto`: The experiences with advantages.
`Dict`: The metrics for logging.
"""

@classmethod
@abstractmethod
def default_args(cls) -> Dict:
"""
Returns:
`Dict`: The default init arguments for the advantage function.
"""
42 changes: 42 additions & 0 deletions trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -0,0 +1,42 @@
"""GRPO advantage computation

Adapted from compute_advantage_ppo in original ray_trainer.py
"""

from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.trainer.verl import core_algos


@ADVANTAGE_FN.register_module("grpo_adv_fn")
class GRPOAdvantageFn(AdvantageFn):
"""GRPO advantage computation"""

def __init__(self) -> None:
pass

def __call__(
self,
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=exps.batch["token_level_rewards"],
eos_mask=exps.batch["response_mask"],
index=exps.non_tensor_batch["uid"],
)
exps.batch["advantages"] = advantages
exps.batch["returns"] = returns

metrics = {
# TODO: add meaningful metrics
}

return exps, metrics

@classmethod
def default_args(cls) -> Dict:
return {}
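For intuition, a standalone sketch of the group-relative baseline behind GRPO outcome advantages: responses sharing the same prompt `uid` are normalized against their group statistics, and the per-sequence value is then broadcast over response tokens via the mask. This is illustrative only; the actual computation is delegated to `core_algos.compute_grpo_outcome_advantage` and may differ in details.

```python
import torch


def grpo_outcome_advantage_sketch(scores: torch.Tensor, uids, eps: float = 1e-6) -> torch.Tensor:
    """Illustrative only: normalize each sequence-level reward within its prompt group."""
    advantages = torch.zeros_like(scores)
    for uid in set(uids):
        idx = [i for i, u in enumerate(uids) if u == uid]
        group = scores[idx]
        if len(idx) > 1:
            advantages[idx] = (group - group.mean()) / (group.std() + eps)
        else:
            advantages[idx] = 0.0  # a lone response has no group baseline
    return advantages
```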
45 changes: 45 additions & 0 deletions trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -0,0 +1,45 @@
"""OPMD advantage computation

Adapted from compute_advantage_opmd in original ray_trainer.py
"""

from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.trainer.verl import core_algos


@ADVANTAGE_FN.register_module("opmd_adv_fn")
class OPMDAdvantageFn(AdvantageFn):
"""OPMD advantage computation"""

def __init__(self) -> None:
pass

def __call__(
self,
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
advantages, returns = core_algos.compute_opmd_outcome_advantage(
token_level_rewards=exps.batch["token_level_rewards"],
eos_mask=exps.batch["response_mask"],
# TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
index=exps.non_tensor_batch["uid"],
opmd_baseline="mean",
tau=1.0,
)
exps.batch["advantages"] = advantages
exps.batch["returns"] = returns

metrics = {
# TODO: add meaningful metrics
}

return exps, metrics

@classmethod
def default_args(cls) -> Dict:
return {}
50 changes: 50 additions & 0 deletions trinity/algorithm/advantage_fn/ppo_advantage.py
@@ -0,0 +1,50 @@
"""PPO's GAE advantage computation

Adapted from compute_advantage_ppo in original ray_trainer.py
"""

from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.trainer.verl import core_algos


@ADVANTAGE_FN.register_module("ppo_adv_fn")
class PPOAdvantageFn(AdvantageFn):
def __init__(
self,
gamma: float = 1.0,
lam: float = 1.0,
) -> None:
self.gamma = gamma
self.lam = lam

def __call__(
self,
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
advantages, returns = core_algos.compute_gae_advantage_return(
token_level_rewards=exps.batch["token_level_rewards"],
values=exps.batch["values"],
eos_mask=exps.batch["response_mask"],
gamma=self.gamma,
lam=self.lam,
)
exps.batch["advantages"] = advantages
exps.batch["returns"] = returns

metrics = {
# TODO: add meaningful metrics
}

return exps, metrics

@classmethod
def default_args(cls) -> Dict:
return {
"gamma": 1.0,
"lam": 1.0,
}
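For reference, the textbook GAE recursion that `compute_gae_advantage_return` implements in spirit (delta is the TD residual, advantages accumulate backwards, returns are advantages plus values); this sketch handles a single response and is not the batched, masked verl implementation:

```python
import torch


def gae_sketch(rewards: torch.Tensor, values: torch.Tensor, response_mask: torch.Tensor,
               gamma: float = 1.0, lam: float = 1.0):
    """Illustrative GAE over one [T]-length response."""
    T = rewards.shape[-1]
    advantages = torch.zeros_like(rewards)
    last_adv = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        last_adv = delta + gamma * lam * last_adv             # GAE recursion
        advantages[t] = last_adv
    returns = advantages + values
    return advantages * response_mask, returns * response_mask
```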
42 changes: 42 additions & 0 deletions trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py
@@ -0,0 +1,42 @@
"""REINFORCE++ advantage computation

Adapted from compute_advantage_ppo in original ray_trainer.py
"""

from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.trainer.verl import core_algos


@ADVANTAGE_FN.register_module("reinforceplusplus_adv_fn")
class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn):
def __init__(self, gamma: float = 1.0) -> None:
self.gamma = gamma

def __call__(
self,
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards=exps.batch["token_level_rewards"],
eos_mask=exps.batch["response_mask"],
gamma=self.gamma,
)
exps.batch["advantages"] = advantages
exps.batch["returns"] = returns

metrics = {
# TODO: add meaningful metrics
}

return exps, metrics

@classmethod
def default_args(cls) -> Dict:
return {
"gamma": 1.0,
}
40 changes: 40 additions & 0 deletions trinity/algorithm/advantage_fn/remax_advantage.py
@@ -0,0 +1,40 @@
"""REMAX advantage computation

Adapted from compute_advantage_ppo in original ray_trainer.py
"""

from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.trainer.verl import core_algos


@ADVANTAGE_FN.register_module("remax_adv_fn")
class REMAXAdvantageFn(AdvantageFn):
def __init__(self) -> None:
pass

def __call__(
self,
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
advantages, returns = core_algos.compute_remax_outcome_advantage(
token_level_rewards=exps.batch["token_level_rewards"],
reward_baselines=exps.batch["reward_baselines"],
eos_mask=exps.batch["response_mask"],
)
exps.batch["advantages"] = advantages
exps.batch["returns"] = returns

metrics = {
# TODO: add meaningful metrics
}

return exps, metrics

@classmethod
def default_args(cls) -> Dict:
return {}
40 changes: 40 additions & 0 deletions trinity/algorithm/advantage_fn/rloo_advantage.py
@@ -0,0 +1,40 @@
"""RLOO advantage computation

Adapted from compute_advantage_ppo in original ray_trainer.py
"""

from typing import Dict, Tuple

from verl import DataProto

from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.trainer.verl import core_algos


@ADVANTAGE_FN.register_module("rloo_adv_fn")
class RLOOAdvantageFn(AdvantageFn):
def __init__(self) -> None:
pass

def __call__(
self,
exps: DataProto,
**kwargs,
) -> Tuple[DataProto, Dict]:
advantages, returns = core_algos.compute_rloo_outcome_advantage(
token_level_rewards=exps.batch["token_level_rewards"],
eos_mask=exps.batch["response_mask"],
index=exps.non_tensor_batch["uid"],
)
exps.batch["advantages"] = advantages
exps.batch["returns"] = returns

metrics = {
# TODO: add meaningful metrics
}

return exps, metrics

@classmethod
def default_args(cls) -> Dict:
return {}
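For intuition, the leave-one-out baseline behind RLOO (illustrative sketch; the actual logic lives in `core_algos.compute_rloo_outcome_advantage`): each response's advantage is its reward minus the mean reward of the other responses to the same prompt.

```python
import torch


def rloo_advantage_sketch(scores: torch.Tensor) -> torch.Tensor:
    """Illustrative only: `scores` holds k rewards for k responses to one prompt (k > 1)."""
    k = scores.shape[0]
    baseline = (scores.sum() - scores) / (k - 1)  # mean reward of the other k - 1 responses
    return scores - baseline
```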
14 changes: 11 additions & 3 deletions trinity/common/config.py
@@ -173,13 +173,15 @@ class AlgorithmConfig:
algorithm_type: AlgorithmType = AlgorithmType.PPO
# for GRPO-like algorithms, repeat each task for `repeat_times` times
repeat_times: int = 1
gamma: Optional[float] = None
lam: Optional[float] = None

policy_loss_fn: str = "ppo"
# If not set, use PolicyLossFn.default_args()
policy_loss_fn_args: Optional[dict] = None

advantage_fn_type: str = "ppo_adv_fn"
# If not set, use AdvantageFn.default_args()
advantage_fn_args: Optional[dict] = None


@dataclass
class ClusterConfig:
@@ -470,14 +472,20 @@ def _check_buffer(self) -> None: # noqa: C901
self.buffer.tokenizer_path = self.model.model_path

def _check_algorithm(self) -> None:
from trinity.algorithm import POLICY_LOSS_FN
from trinity.algorithm import ADVANTAGE_FN, POLICY_LOSS_FN

policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
if policy_fn_cls is None:
raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}")
if self.algorithm.policy_loss_fn_args is None:
self.algorithm.policy_loss_fn_args = policy_fn_cls.default_args()

advantage_fn_cls = ADVANTAGE_FN.get(self.algorithm.advantage_fn_type)
if advantage_fn_cls is None:
raise ValueError(f"Invalid advantage_fn_type: {self.algorithm.advantage_fn_type}")
if self.algorithm.advantage_fn_args is None:
self.algorithm.advantage_fn_args = advantage_fn_cls.default_args()

def check_and_update(self) -> None: # noqa: C901
"""Check and update the config."""
self._check_deprecated()
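Putting the pieces together, a hedged sketch of how a trainer might consume the validated config to build and call the configured advantage function; the helper below is hypothetical and not part of this diff:

```python
from trinity.algorithm import ADVANTAGE_FN


def build_advantage_fn(algorithm_config):
    """Hypothetical helper: resolve and construct the configured advantage function."""
    advantage_fn_cls = ADVANTAGE_FN.get(algorithm_config.advantage_fn_type)
    # advantage_fn_args is filled with default_args() by _check_algorithm when unset.
    return advantage_fn_cls(**algorithm_config.advantage_fn_args)


# Later, inside a training step, with `exps` a verl DataProto that already holds rewards:
# exps, advantage_metrics = advantage_fn(exps)
```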