Commit fec7f3c

Refactor train step (#69)
1 parent 2d8f0c1 commit fec7f3c

22 files changed: +568, -1984 lines

tests/buffer/queue_test.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 from trinity.buffer.reader.queue_reader import QueueReader
 from trinity.buffer.writer.queue_writer import QueueWriter
 from trinity.common.config import BufferConfig, StorageConfig
-from trinity.common.constants import AlgorithmType, StorageType
+from trinity.common.constants import StorageType
 from trinity.common.experience import Experience


@@ -15,7 +15,7 @@ def test_queue_buffer(self):
         read_batch_size = 4
         meta = StorageConfig(
             name="test_buffer",
-            algorithm_type=AlgorithmType.PPO,
+            algorithm_type="ppo",
             storage_type=StorageType.QUEUE,
         )
         config = BufferConfig(
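Both hunks make the same substitution: `algorithm_type` is now the plain string under which an algorithm is registered, not a member of the removed `AlgorithmType` enum. A minimal sketch of the new-style construction, using only values that appear in this diff:

from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType

# algorithm_type is now a registry key such as "sft", "ppo", "grpo", or "dpo".
meta = StorageConfig(
    name="test_buffer",
    algorithm_type="ppo",
    storage_type=StorageType.QUEUE,
)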

tests/buffer/sql_test.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 from trinity.buffer.reader.sql_reader import SQLReader
 from trinity.buffer.writer.sql_writer import SQLWriter
 from trinity.common.config import BufferConfig, StorageConfig
-from trinity.common.constants import AlgorithmType, StorageType
+from trinity.common.constants import StorageType
 from trinity.common.experience import Experience

 db_path = os.path.join(os.path.dirname(__file__), "test.db")
@@ -19,7 +19,7 @@ def test_create_sql_buffer(self) -> None:
         read_batch_size = 4
         meta = StorageConfig(
             name="test_buffer",
-            algorithm_type=AlgorithmType.PPO,
+            algorithm_type="ppo",
             path=f"sqlite:///{db_path}",
             storage_type=StorageType.SQL,
         )

tests/explorer/runner_pool_test.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from tests.tools import get_unittest_dataset_config
 from trinity.buffer.reader.queue_reader import QueueReader
 from trinity.common.config import InferenceModelConfig, StorageConfig, load_config
-from trinity.common.constants import AlgorithmType, StorageType
+from trinity.common.constants import StorageType
 from trinity.common.experience import Experience
 from trinity.common.models.model import InferenceModel
 from trinity.common.workflows import Task
@@ -105,7 +105,7 @@ def setUp(self):
         ) = StorageConfig(
             name="test",
             storage_type=StorageType.QUEUE,
-            algorithm_type=AlgorithmType.PPO,
+            algorithm_type="ppo",
         )
         self.queue = QueueReader(
             self.config.buffer.trainer_input.experience_buffer, self.config.buffer

tests/trainer/trainer_test.py

Lines changed: 5 additions & 5 deletions
@@ -15,7 +15,7 @@
     get_unittest_dataset_config,
 )
 from trinity.cli.launcher import bench, both, train
-from trinity.common.constants import AlgorithmType, MonitorType, SyncMethod
+from trinity.common.constants import MonitorType, SyncMethod


 class BaseTrainerCase(RayUnittestBase):
@@ -119,7 +119,7 @@ class TestTrainerGSM8K(BaseTrainerCase):
     def test_trainer(self):
         """Test GSM8K."""
         # test both mode
-        self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+        self.config.algorithm.algorithm_type = "grpo"
         self.config.algorithm.repeat_times = 4
         # self.config.algorithm.repeat_times = 8  # TODO: used for real testing
         self.config.algorithm.advantage_fn = "grpo"
@@ -157,7 +157,7 @@ class TestTrainerGSM8KWithSFT(BaseTrainerCase):
     def test_trainer(self):
         """Test GSM8K With SFT."""
         # test both mode
-        self.config.algorithm.algorithm_type = AlgorithmType.GRPO
+        self.config.algorithm.algorithm_type = "grpo"
         self.config.algorithm.repeat_times = 4
         self.config.algorithm.advantage_fn = "grpo"
         self.config.algorithm.advantage_fn_args = {}
@@ -174,7 +174,7 @@ def test_trainer(self):
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
         rollout_metrics = parser.metric_list("rollout")
         self.assertTrue(len(rollout_metrics) > 0)
-        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
         actor_metrics = parser.metric_list("actor")
         self.assertTrue(len(actor_metrics) > 0)
         self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2)  # SFT
@@ -193,7 +193,7 @@ def test_trainer(self):
         """Test DPO."""
         # test both mode
         self.config.mode = "train"
-        self.config.algorithm.algorithm_type = AlgorithmType.DPO
+        self.config.algorithm.algorithm_type = "dpo"
         self.config.algorithm.policy_loss_fn = "dpo"
         self.config.algorithm.policy_loss_fn_args = {}
         # self.config.buffer.batch_size = 32

trinity/algorithm/algorithm.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+# -*- coding: utf-8 -*-
+"""Algorithm classes."""
+
+from abc import ABC, ABCMeta
+from typing import Dict
+
+from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel
+from trinity.common.config import Config
+from trinity.common.constants import SyncMethod
+from trinity.common.experience import Experience, Experiences
+from trinity.utils.log import get_logger
+from trinity.utils.registry import Registry
+
+logger = get_logger(__name__)
+
+ALGORITHM_TYPE = Registry("algorithm")
+
+
+class ConstantMeta(ABCMeta):
+    def __setattr__(cls, name, value):
+        if name in cls.__dict__:
+            raise AttributeError(f"{name} is already defined in {cls.__name__}")
+        return super().__setattr__(name, value)
+
+
+class AlgorithmType(ABC, metaclass=ConstantMeta):
+    use_critic: bool
+    use_reference: bool
+    use_advantage: bool
+    use_rollout: bool
+    can_balance_batch: bool
+    schema: type
+
+    @classmethod
+    def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
+        return Experiences.gather_experiences(exps, pad_token_id)
+
+    @classmethod
+    def get_default_config(cls) -> Dict:
+        raise NotImplementedError
+
+    @classmethod
+    def name(cls) -> str:
+        return cls._name
+
+    @classmethod
+    def check_config(cls, config: Config) -> None:
+        pass
+
+
+@ALGORITHM_TYPE.register_module("sft")
+class SFTAlgorithm(AlgorithmType):
+    """SFT Algorithm."""
+
+    use_critic: bool = False
+    use_reference: bool = False
+    use_advantage: bool = False
+    use_rollout: bool = False
+    can_balance_batch: bool = True
+    schema: type = SFTDataModel
+
+    @classmethod
+    def get_default_config(cls) -> Dict:
+        return {
+            "policy_loss_fn": "sft",
+            "kl_loss_fn": "none",
+            "entropy_loss_fn": "none",
+        }
+
+
+@ALGORITHM_TYPE.register_module("ppo")
+class PPOAlgorithm(AlgorithmType):
+    """PPO Algorithm."""
+
+    use_critic: bool = True
+    use_reference: bool = True
+    use_advantage: bool = True
+    use_rollout: bool = True
+    can_balance_batch: bool = True
+    schema: type = ExperienceModel
+
+    @classmethod
+    def get_default_config(cls) -> Dict:
+        return {
+            "repeat_times": 1,
+            "policy_loss_fn": "ppo",
+            "advantage_fn": "ppo",
+            "kl_penalty_fn": "none",
+            "kl_loss_fn": "k2",
+            "entropy_loss_fn": "basic",
+        }
+
+
+@ALGORITHM_TYPE.register_module("grpo")
+class GRPOAlgorithm(AlgorithmType):
+    """GRPO algorithm."""
+
+    use_critic: bool = False
+    use_reference: bool = True
+    use_advantage: bool = True
+    use_rollout: bool = True
+    can_balance_batch: bool = True
+    schema: type = ExperienceModel
+
+    @classmethod
+    def get_default_config(cls) -> Dict:
+        return {
+            "repeat_times": 2,
+            "policy_loss_fn": "ppo",
+            "advantage_fn": "grpo",
+            "kl_penalty_fn": "none",
+            "kl_loss_fn": "k2",
+            "entropy_loss_fn": "basic",
+        }
+
+
+@ALGORITHM_TYPE.register_module("opmd")
+class OPMDAlgorithm(AlgorithmType):
+    """OPMD algorithm."""
+
+    use_critic: bool = False
+    use_reference: bool = True
+    use_advantage: bool = True
+    use_rollout: bool = True
+    can_balance_batch: bool = True
+    schema: type = ExperienceModel
+
+    @classmethod
+    def get_default_config(cls) -> Dict:
+        return {
+            "repeat_times": 2,
+            "policy_loss_fn": "opmd",
+            "advantage_fn": "opmd",
+            "kl_penalty_fn": "none",
+            "kl_loss_fn": "k2",
+            "entropy_loss_fn": "basic",
+        }
+
+
+@ALGORITHM_TYPE.register_module("dpo")
+class DPOAlgorithm(AlgorithmType):
+    """DPO algorithm."""
+
+    use_critic: bool = False
+    use_reference: bool = True
+    use_advantage: bool = False
+    use_rollout: bool = False
+    can_balance_batch: bool = False
+    schema: type = DPODataModel
+
+    @classmethod
+    def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
+        return Experiences.gather_dpo_experiences(exps, pad_token_id)
+
+    @classmethod
+    def get_default_config(cls) -> Dict:
+        return {
+            "repeat_times": 2,  # fake repeat times
+            "policy_loss_fn": "dpo",
+            "kl_loss_fn": "k2",
+            "entropy_loss_fn": "basic",
+        }
+
+    @classmethod
+    def check_config(cls, config: Config) -> None:
+        if config.mode == "train":
+            if (
+                config.buffer.trainer_input.experience_buffer is None
+                or not config.buffer.trainer_input.experience_buffer.path
+            ):
+                raise ValueError(
+                    "`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == dpo`"
+                )
+        elif config.mode in ["both", "explore"]:
+            raise ValueError(f"DPO does not support `{config.mode}` mode")
+
+        if config.synchronizer.sync_method != SyncMethod.CHECKPOINT:
+            config.synchronizer.sync_method = SyncMethod.CHECKPOINT
+            logger.warning(
+                "DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
+            )
+        if config.algorithm.repeat_times != 2:
+            config.algorithm.repeat_times = 2
+            logger.warning(
+                "DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2."
+            )
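This new module replaces the `AlgorithmType` enum from `trinity.common.constants` with a registry of algorithm classes whose class-level flags describe what each algorithm needs from the trainer. A rough usage sketch based only on the code above; the trainer-side wiring that consumes these flags is not part of this commit:

from trinity.algorithm.algorithm import ALGORITHM_TYPE

# Resolve an algorithm class from the same string that now appears in
# config.algorithm.algorithm_type and StorageConfig.algorithm_type.
algo_cls = ALGORITHM_TYPE.get("grpo")

assert algo_cls.use_critic is False       # GRPO builds no critic model
assert algo_cls.use_reference is True     # but still needs a reference model
defaults = algo_cls.get_default_config()  # e.g. {"repeat_times": 2, "advantage_fn": "grpo", ...}

Because of `ConstantMeta`, reassigning a flag that a class already defines (for example `PPOAlgorithm.use_critic = False`) raises an `AttributeError`, so these class attributes behave like constants once defined.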
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+"""AlgorithmManager for switching between SFT and RFT."""
+
+from trinity.algorithm.algorithm import ALGORITHM_TYPE
+from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ENTROPY_LOSS_FN
+from trinity.algorithm.kl_fn.kl_fn import KL_FN
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
+from trinity.common.config import AlgorithmConfig, Config
+
+
+class AlgorithmManager:
+    def __init__(self, config: Config):
+        self.config = config
+        sft_type = ALGORITHM_TYPE.get("sft")
+        sft_default_config = sft_type.get_default_config()
+        self.sft_algorithm_config = AlgorithmConfig(
+            algorithm_type="sft",
+            **sft_default_config,
+        )
+        policy_fn_cls = POLICY_LOSS_FN.get(self.sft_algorithm_config.policy_loss_fn)
+        self.sft_algorithm_config.policy_loss_fn_args = policy_fn_cls.default_args()
+        kl_loss_fn_cls = KL_FN.get(self.sft_algorithm_config.kl_loss_fn)
+        self.sft_algorithm_config.kl_loss_fn_args = kl_loss_fn_cls.default_args()
+        entropy_loss_fn_cls = ENTROPY_LOSS_FN.get(self.sft_algorithm_config.entropy_loss_fn)
+        self.sft_algorithm_config.entropy_loss_fn_args = entropy_loss_fn_cls.default_args()
+
+    def get_current_algorithm_config(self, global_steps: int):
+        if global_steps <= self.config.buffer.trainer_input.sft_warmup_steps:
+            return self.sft_algorithm_config
+        else:
+            return self.config.algorithm
+
+    def need_save(self, global_steps: int):
+        return global_steps == self.config.buffer.trainer_input.sft_warmup_steps
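The manager assembles an SFT configuration from the registry defaults and hands it out during warmup, then falls back to the user-configured algorithm. A sketch of the intended call pattern; the loop and `total_steps` are illustrative, and `config` is assumed to be an already-loaded trinity `Config` with `buffer.trainer_input.sft_warmup_steps` set:

manager = AlgorithmManager(config)

for step in range(1, total_steps + 1):  # total_steps is a placeholder
    algorithm_config = manager.get_current_algorithm_config(step)
    # Steps 1..sft_warmup_steps receive the SFT defaults ("policy_loss_fn": "sft",
    # "kl_loss_fn": "none", ...); later steps return config.algorithm unchanged.
    if manager.need_save(step):
        # Exactly one step sits on the SFT -> RFT boundary; saving here keeps
        # the warmup result before the algorithm switches.
        pass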

trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py

Lines changed: 22 additions & 0 deletions
@@ -61,3 +61,25 @@ def __call__(
     @classmethod
     def default_args(cls) -> Dict:
         return {"entropy_coef": 0.0}
+
+
+@ENTROPY_LOSS_FN.register_module("none")
+class DummyEntropyLossFn(EntropyLossFn):
+    """
+    Dummy entropy loss function.
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(
+        self,
+        entropy: torch.Tensor,
+        action_mask: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        return torch.tensor(0.0), {}
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {}
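Registering a no-op under "none" lets configurations such as the SFT defaults above ("entropy_loss_fn": "none") resolve through the same registry lookup that `AlgorithmManager` performs. A small sketch of that path; the tensor shapes are illustrative:

import torch

from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ENTROPY_LOSS_FN

fn_cls = ENTROPY_LOSS_FN.get("none")       # resolves to DummyEntropyLossFn
loss_fn = fn_cls(**fn_cls.default_args())  # default_args() is empty for the dummy
loss, metrics = loss_fn(
    entropy=torch.zeros(2, 8),
    action_mask=torch.ones(2, 8),
)
# loss is a constant tensor(0.0) and metrics is {}, so the entropy term drops out.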

trinity/algorithm/kl_fn/kl_fn.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def default_args(cls):


 @KL_FN.register_module("none")
-class DummyFn(KLFn):
+class DummyKLFn(KLFn):
     """
     Dummy KL function.
     """

trinity/buffer/buffer.py

Lines changed: 3 additions & 3 deletions
@@ -41,9 +41,9 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig
     elif storage_config.storage_type == StorageType.FILE:
         from trinity.buffer.reader.file_reader import FILE_READERS

-        file_read_type = storage_config.algorithm_type
-        if file_read_type is not None:
-            file_read_type = file_read_type.value
+        algorithm_type = storage_config.algorithm_type
+        if algorithm_type is not None:
+            file_read_type = algorithm_type
         else:
             file_read_type = "rollout"
         return FILE_READERS.get(file_read_type)(storage_config, buffer_config)

trinity/buffer/reader/file_reader.py

Lines changed: 4 additions & 3 deletions
@@ -6,9 +6,10 @@
 import transformers
 from datasets import load_dataset

+from trinity.algorithm.algorithm import DPOAlgorithm, SFTAlgorithm
 from trinity.buffer.buffer_reader import BufferReader
 from trinity.common.config import BufferConfig, StorageConfig
-from trinity.common.constants import AlgorithmType, PromptType, ReadStrategy, TaskType
+from trinity.common.constants import PromptType, ReadStrategy, TaskType
 from trinity.common.experience import Experience
 from trinity.common.rewards import REWARD_FUNCTIONS
 from trinity.common.workflows import WORKFLOWS, Task
@@ -17,7 +18,7 @@
 FILE_READERS = Registry("file_readers")


-@FILE_READERS.register_module(AlgorithmType.SFT.value)
+@FILE_READERS.register_module(SFTAlgorithm.name())
 class SFTDataReader(BufferReader):
     """Reader for SFT file data."""

@@ -96,7 +97,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
         return exp_list


-@FILE_READERS.register_module(AlgorithmType.DPO.value)
+@FILE_READERS.register_module(DPOAlgorithm.name())
 class DPODataReader(BufferReader):
     def __init__(self, meta: StorageConfig, config: BufferConfig):
         self.split = meta.split
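The reader registry is now keyed by the algorithm classes' registered names instead of enum values, so the plain string carried in `StorageConfig.algorithm_type` selects a reader directly (matching the simplified `get_buffer_reader` above). A hedged sketch, assuming `SFTAlgorithm.name()` returns its registered key "sft" and that `storage_config` and `buffer_config` are pre-built config objects:

from trinity.buffer.reader.file_reader import FILE_READERS

# "sft" is both the ALGORITHM_TYPE key and the FILE_READERS key, so the
# value of storage_config.algorithm_type can be used verbatim.
reader_cls = FILE_READERS.get("sft")                # -> SFTDataReader
reader = reader_cls(storage_config, buffer_config)  # same call get_buffer_reader makes
experiences = reader.read()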
