Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/buffer/queue_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.constants import StorageType
from trinity.common.experience import Experience


Expand All @@ -15,7 +15,7 @@ def test_queue_buffer(self):
read_batch_size = 4
meta = StorageConfig(
name="test_buffer",
algorithm_type=AlgorithmType.PPO,
algorithm_type="ppo",
storage_type=StorageType.QUEUE,
)
config = BufferConfig(
Expand Down
4 changes: 2 additions & 2 deletions tests/buffer/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.constants import StorageType
from trinity.common.experience import Experience

db_path = os.path.join(os.path.dirname(__file__), "test.db")
Expand All @@ -19,7 +19,7 @@ def test_create_sql_buffer(self) -> None:
read_batch_size = 4
meta = StorageConfig(
name="test_buffer",
algorithm_type=AlgorithmType.PPO,
algorithm_type="ppo",
path=f"sqlite:///{db_path}",
storage_type=StorageType.SQL,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/explorer/runner_pool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tests.tools import get_unittest_dataset_config
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.common.config import InferenceModelConfig, StorageConfig, load_config
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.constants import StorageType
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel
from trinity.common.workflows import Task
Expand Down Expand Up @@ -105,7 +105,7 @@ def setUp(self):
) = StorageConfig(
name="test",
storage_type=StorageType.QUEUE,
algorithm_type=AlgorithmType.PPO,
algorithm_type="ppo",
)
self.queue = QueueReader(
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
Expand Down
10 changes: 5 additions & 5 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
get_unittest_dataset_config,
)
from trinity.cli.launcher import bench, both, train
from trinity.common.constants import AlgorithmType, MonitorType, SyncMethod
from trinity.common.constants import MonitorType, SyncMethod


class BaseTrainerCase(RayUnittestBase):
Expand Down Expand Up @@ -119,7 +119,7 @@ class TestTrainerGSM8K(BaseTrainerCase):
def test_trainer(self):
"""Test GSM8K."""
# test both mode
self.config.algorithm.algorithm_type = AlgorithmType.GRPO
self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.repeat_times = 4
# self.config.algorithm.repeat_times = 8 # TODO: used for real testing
self.config.algorithm.advantage_fn = "grpo"
Expand Down Expand Up @@ -157,7 +157,7 @@ class TestTrainerGSM8KWithSFT(BaseTrainerCase):
def test_trainer(self):
"""Test GSM8K With SFT."""
# test both mode
self.config.algorithm.algorithm_type = AlgorithmType.GRPO
self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.repeat_times = 4
self.config.algorithm.advantage_fn = "grpo"
self.config.algorithm.advantage_fn_args = {}
Expand All @@ -174,7 +174,7 @@ def test_trainer(self):
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) # SFT
Expand All @@ -193,7 +193,7 @@ def test_trainer(self):
"""Test DPO."""
# test both mode
self.config.mode = "train"
self.config.algorithm.algorithm_type = AlgorithmType.DPO
self.config.algorithm.algorithm_type = "dpo"
self.config.algorithm.policy_loss_fn = "dpo"
self.config.algorithm.policy_loss_fn_args = {}
# self.config.buffer.batch_size = 32
Expand Down
186 changes: 186 additions & 0 deletions trinity/algorithm/algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# -*- coding: utf-8 -*-
"""Algorithm classes."""

from abc import ABC, ABCMeta
from typing import Dict

from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel
from trinity.common.config import Config
from trinity.common.constants import SyncMethod
from trinity.common.experience import Experience, Experiences
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry

logger = get_logger(__name__)

ALGORITHM_TYPE = Registry("algorithm")


class ConstantMeta(ABCMeta):
def __setattr__(cls, name, value):
if name in cls.__dict__:
raise AttributeError(f"{name} is already defined in {cls.__name__}")
return super().__setattr__(name, value)


class AlgorithmType(ABC, metaclass=ConstantMeta):
use_critic: bool
use_reference: bool
use_advantage: bool
use_rollout: bool
can_balance_batch: bool
schema: type

@classmethod
def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
return Experiences.gather_experiences(exps, pad_token_id)

@classmethod
def get_default_config(cls) -> Dict:
raise NotImplementedError

@classmethod
def name(cls) -> str:
return cls._name

@classmethod
def check_config(cls, config: Config) -> None:
pass


@ALGORITHM_TYPE.register_module("sft")
class SFTAlgorithm(AlgorithmType):
"""SFT Algorithm."""

use_critic: bool = False
use_reference: bool = False
use_advantage: bool = False
use_rollout: bool = False
can_balance_batch: bool = True
schema: type = SFTDataModel

@classmethod
def get_default_config(cls) -> Dict:
return {
"policy_loss_fn": "sft",
"kl_loss_fn": "none",
"entropy_loss_fn": "none",
}


@ALGORITHM_TYPE.register_module("ppo")
class PPOAlgorithm(AlgorithmType):
"""PPO Algorithm."""

use_critic: bool = True
use_reference: bool = True
use_advantage: bool = True
use_rollout: bool = True
can_balance_batch: bool = True
schema: type = ExperienceModel

@classmethod
def get_default_config(cls) -> Dict:
return {
"repeat_times": 1,
"policy_loss_fn": "ppo",
"advantage_fn": "ppo",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
"entropy_loss_fn": "basic",
}


@ALGORITHM_TYPE.register_module("grpo")
class GRPOAlgorithm(AlgorithmType):
"""GRPO algorithm."""

use_critic: bool = False
use_reference: bool = True
use_advantage: bool = True
use_rollout: bool = True
can_balance_batch: bool = True
schema: type = ExperienceModel

@classmethod
def get_default_config(cls) -> Dict:
return {
"repeat_times": 2,
"policy_loss_fn": "ppo",
"advantage_fn": "grpo",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
"entropy_loss_fn": "basic",
}


@ALGORITHM_TYPE.register_module("opmd")
class OPMDAlgorithm(AlgorithmType):
"""OPMD algorithm."""

use_critic: bool = False
use_reference: bool = True
use_advantage: bool = True
use_rollout: bool = True
can_balance_batch: bool = True
schema: type = ExperienceModel

@classmethod
def get_default_config(cls) -> Dict:
return {
"repeat_times": 2,
"policy_loss_fn": "opmd",
"advantage_fn": "opmd",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
"entropy_loss_fn": "basic",
}


@ALGORITHM_TYPE.register_module("dpo")
class DPOAlgorithm(AlgorithmType):
"""DPO algorithm."""

use_critic: bool = False
use_reference: bool = True
use_advantage: bool = False
use_rollout: bool = False
can_balance_batch: bool = False
schema: type = DPODataModel

@classmethod
def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
return Experiences.gather_dpo_experiences(exps, pad_token_id)

@classmethod
def get_default_config(cls) -> Dict:
return {
"repeat_times": 2, # fake repeat times
"policy_loss_fn": "dpo",
"kl_loss_fn": "k2",
"entropy_loss_fn": "basic",
}

@classmethod
def check_config(cls, config: Config) -> None:
if config.model == "train":
if (
config.buffer.trainer_input.experience_buffer is None
or not config.buffer.trainer_input.experience_buffer.path
):
raise ValueError(
"`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == dpo`"
)
elif config.mode in ["both", "explore"]:
raise ValueError(f"DPO does not support `{config.mode}` mode")

if config.synchronizer.sync_method != SyncMethod.CHECKPOINT:
config.synchronizer.sync_method = SyncMethod.CHECKPOINT
logger.warning(
"DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
)
if config.algorithm.repeat_times != 2:
config.algorithm.repeat_times = 2
logger.warning(
"DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2."
) # no need to warn
34 changes: 34 additions & 0 deletions trinity/algorithm/algorithm_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
"""AlgorithmManager for switching between SFT and RFT."""

from trinity.algorithm.algorithm import ALGORITHM_TYPE
from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ENTROPY_LOSS_FN
from trinity.algorithm.kl_fn.kl_fn import KL_FN
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
from trinity.common.config import AlgorithmConfig, Config


class AlgorithmManager:
def __init__(self, config: Config):
self.config = config
sft_type = ALGORITHM_TYPE.get("sft")
sft_default_config = sft_type.get_default_config()
self.sft_algorithm_config = AlgorithmConfig(
algorithm_type="sft",
**sft_default_config,
)
policy_fn_cls = POLICY_LOSS_FN.get(self.sft_algorithm_config.policy_loss_fn)
self.sft_algorithm_config.policy_loss_fn_args = policy_fn_cls.default_args()
kl_loss_fn_cls = KL_FN.get(self.sft_algorithm_config.kl_loss_fn)
self.sft_algorithm_config.kl_loss_fn_args = kl_loss_fn_cls.default_args()
entropy_loss_fn_cls = ENTROPY_LOSS_FN.get(self.sft_algorithm_config.entropy_loss_fn)
self.sft_algorithm_config.entropy_loss_fn_args = entropy_loss_fn_cls.default_args()

def get_current_algorithm_config(self, global_steps: int):
if global_steps <= self.config.buffer.trainer_input.sft_warmup_steps:
return self.sft_algorithm_config
else:
return self.config.algorithm

def need_save(self, global_steps: int):
return global_steps == self.config.buffer.trainer_input.sft_warmup_steps
22 changes: 22 additions & 0 deletions trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,25 @@ def __call__(
@classmethod
def default_args(cls) -> Dict:
return {"entropy_coef": 0.0}


@ENTROPY_LOSS_FN.register_module("none")
class DummyEntropyLossFn(EntropyLossFn):
"""
Dummy entropy loss function.
"""

def __init__(self):
pass

def __call__(
self,
entropy: torch.Tensor,
action_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
return torch.tensor(0.0), {}

@classmethod
def default_args(cls) -> Dict:
return {}
2 changes: 1 addition & 1 deletion trinity/algorithm/kl_fn/kl_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def default_args(cls):


@KL_FN.register_module("none")
class DummyFn(KLFn):
class DummyKLFn(KLFn):
"""
Dummy KL function.
"""
Expand Down
6 changes: 3 additions & 3 deletions trinity/buffer/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig
elif storage_config.storage_type == StorageType.FILE:
from trinity.buffer.reader.file_reader import FILE_READERS

file_read_type = storage_config.algorithm_type
if file_read_type is not None:
file_read_type = file_read_type.value
algorithm_type = storage_config.algorithm_type
if algorithm_type is not None:
file_read_type = algorithm_type
else:
file_read_type = "rollout"
return FILE_READERS.get(file_read_type)(storage_config, buffer_config)
Expand Down
7 changes: 4 additions & 3 deletions trinity/buffer/reader/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import transformers
from datasets import load_dataset

from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm
from trinity.buffer.buffer_reader import BufferReader
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType
from trinity.common.constants import PromptType, ReadStrategy, TaskType
from trinity.common.experience import Experience
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
Expand All @@ -17,7 +18,7 @@
FILE_READERS = Registry("file_readers")


@FILE_READERS.register_module(AlgorithmType.SFT.value)
@FILE_READERS.register_module(SFTAlgorithm.name())
class SFTDataReader(BufferReader):
"""Reader for SFT file data."""

Expand Down Expand Up @@ -96,7 +97,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
return exp_list


@FILE_READERS.register_module(AlgorithmType.DPO.value)
@FILE_READERS.register_module(DPOAlgorithm.name())
class DPODataReader(BufferReader):
def __init__(self, meta: StorageConfig, config: BufferConfig):
self.split = meta.split
Expand Down
Loading