Skip to content

Commit 59c3670

Browse files
committed
[WIP] refactor algorithm_type
1 parent cd4e85e commit 59c3670

File tree

7 files changed

+265
-113
lines changed

7 files changed

+265
-113
lines changed

trinity/algorithm/algorithm.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Dict
3+
4+
from trinity.common.experience import Experience, Experiences
5+
from trinity.utils.registry import Registry
6+
7+
ALGORITHM = Registry("algorithm")
8+
9+
10+
class Algorithm(ABC):
11+
use_critic: bool
12+
use_reference: bool
13+
use_advantage: bool
14+
can_balance_batch: bool
15+
16+
@classmethod
17+
def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
18+
return Experiences.gather_experiences(exps, pad_token_id)
19+
20+
@abstractmethod
21+
@classmethod
22+
def get_default_config(cls) -> Dict:
23+
pass
24+
25+
26+
@ALGORITHM.register_module("sft")
27+
class SFTAlgorithm(Algorithm):
28+
"""SFT Algorithm."""
29+
30+
use_critic: bool = False
31+
use_reference: bool = False
32+
use_advantage: bool = False
33+
can_balance_batch: bool = True
34+
35+
@classmethod
36+
def get_default_config(cls) -> Dict:
37+
return {
38+
"policy_loss_fn": "sft",
39+
"kl_loss_fn": "none",
40+
"entropy_loss_fn": "basic",
41+
}
42+
43+
44+
@ALGORITHM.register_module("ppo")
45+
class PPOAlgorithm(Algorithm):
46+
"""PPO Algorithm."""
47+
48+
use_critic: bool = True
49+
use_reference: bool = True
50+
use_advantage: bool = True
51+
can_balance_batch: bool = True
52+
53+
@classmethod
54+
def get_default_config(cls) -> Dict:
55+
return {
56+
"repeat_times": 1,
57+
"policy_loss_fn": "ppo",
58+
"advantage_fn": "ppo",
59+
"kl_penalty_fn": "k3",
60+
"kl_loss_fn": "k2",
61+
"entropy_loss_fn": "basic",
62+
}
63+
64+
65+
@ALGORITHM.register_module("grpo")
66+
class GRPOAlgorithm(Algorithm):
67+
"""GRPO algorithm."""
68+
69+
use_critic: bool = False
70+
use_reference: bool = True
71+
use_advantage: bool = True
72+
can_balance_batch: bool = True
73+
74+
@classmethod
75+
def get_default_config(cls) -> Dict:
76+
return {
77+
"repeat_times": 2,
78+
"policy_loss_fn": "ppo",
79+
"advantage_fn": "grpo",
80+
"kl_penalty_fn": "k3",
81+
"kl_loss_fn": "k2",
82+
"entropy_loss_fn": "basic",
83+
}
84+
85+
86+
@ALGORITHM.register_module("opmd")
87+
class OPMDAlgorithm(Algorithm):
88+
"""OPMD algorithm."""
89+
90+
use_critic: bool = False
91+
use_reference: bool = True
92+
use_advantage: bool = True
93+
can_balance_batch: bool = True
94+
95+
@classmethod
96+
def get_default_config(cls) -> Dict:
97+
return {
98+
"repeat_times": 2,
99+
"policy_loss_fn": "opmd",
100+
"advantage_fn": "opmd",
101+
"kl_penalty_fn": "k3",
102+
"kl_loss_fn": "k2",
103+
"entropy_loss_fn": "basic",
104+
}
105+
106+
107+
@ALGORITHM.register_module("dpo")
108+
class DPOAlgorithm(Algorithm):
109+
"""DPO algorithm."""
110+
111+
use_critic: bool = False
112+
use_reference: bool = True
113+
use_advantage: bool = False
114+
can_balance_batch: bool = False
115+
116+
@classmethod
117+
def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences:
118+
return Experiences.gather_dpo_experiences(exps, pad_token_id)
119+
120+
@classmethod
121+
def get_default_config(cls) -> Dict:
122+
return {
123+
"repeat_times": 2, # fake repeat times
124+
"policy_loss_fn": "dpo",
125+
"kl_loss_fn": "k2",
126+
"entropy_loss_fn": "basic",
127+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from trinity.algorithm.algorithm import ALGORITHM
2+
from trinity.common.config import AlgorithmConfig, Config
3+
4+
5+
class AlgorithmManager:
6+
def __init__(self, config: Config):
7+
self.config = config
8+
sft_type = ALGORITHM.get("sft")
9+
sft_default_config = sft_type.get_default_config()
10+
self.sft_algorithm_config = AlgorithmConfig(
11+
algorithm_type=sft_type,
12+
**sft_default_config,
13+
)
14+
15+
def get_current_algorithm_config(self, global_steps: int):
16+
if global_steps <= self.config.buffer.trainer_input.sft_warmup_steps:
17+
return self.sft_algorithm_config
18+
else:
19+
return self.config.algorithm.algorithm_type
20+
21+
def need_save(self, global_steps: int):
22+
return global_steps == self.config.buffer.trainer_input.sft_warmup_steps

trinity/common/config.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
"""Configs for RFT."""
33
import os
44
from dataclasses import dataclass, field
5-
from typing import Any, Dict, List, Optional
5+
from typing import Any, Dict, List, Optional, Union
66

77
from omegaconf import OmegaConf
88

9+
from trinity.algorithm.algorithm import ALGORITHM, Algorithm
10+
from trinity.algorithm.algorithm_manager import AlgorithmManager
911
from trinity.common.constants import (
1012
AlgorithmType,
1113
MonitorType,
@@ -170,34 +172,43 @@ class InferenceModelConfig:
170172
class AlgorithmConfig:
171173
"""Config for algorithm."""
172174

173-
algorithm_type: AlgorithmType = AlgorithmType.PPO
175+
algorithm_type: Union[str, Algorithm] = "ppo"
174176
# for GRPO-like algorithms, repeat each task for `repeat_times` times
175177
repeat_times: int = 1
176178

177-
policy_loss_fn: str = "ppo"
179+
policy_loss_fn: str = None # "ppo"
178180
# If not set, use PolicyLossFn.default_args()
179181
policy_loss_fn_args: Optional[dict] = None
180182

181-
advantage_fn: str = "ppo"
183+
advantage_fn: str = None # "ppo"
182184
# If not set, use AdvantageFn.default_args()
183185
advantage_fn_args: Optional[dict] = None
184186

185-
kl_penalty_fn: str = "none" # set to "none" to disable kl penalty in reward
187+
kl_penalty_fn: str = None # "none" # set to "none" to disable kl penalty in reward
186188
# If not set, use kl_penalty_fn.default_args()
187189
kl_penalty_fn_args: Optional[dict] = None
188190

189-
kl_loss_fn: str = "k2" # set to "none" to disable kl loss
191+
kl_loss_fn: str = None # "k2" # set to "none" to disable kl loss
190192
# If not set, use kl_loss_fn.default_args()
191193
kl_loss_fn_args: Optional[dict] = None
192194

193-
entropy_loss_fn: str = "basic"
195+
entropy_loss_fn: str = None # "basic"
194196
# If not set, use entropy_loss_fn.default_args()
195197
entropy_loss_fn_args: Optional[dict] = None
196198

197199
# used for SFT warmup
198200
# TODO: move this to SFT warmup
199201
use_token_level_loss: bool = True
200202

203+
# do not set
204+
algorithm_manager: Optional[AlgorithmManager] = None
205+
206+
def get_current_algorithm_config(self, global_steps: int):
207+
return self.algorithm_manager.get_current_algorithm_config(global_steps)
208+
209+
def need_save(self, global_steps: int):
210+
return self.algorithm_manager.need_save(global_steps)
211+
201212

202213
@dataclass
203214
class ClusterConfig:
@@ -492,6 +503,12 @@ def _check_algorithm(self) -> None:
492503
POLICY_LOSS_FN,
493504
)
494505

506+
self.algorithm.algorithm_manager = AlgorithmManager(self)
507+
self.algorithm.algorithm_type = ALGORITHM.get(self.algorithm.algorithm_type)
508+
for key, value in self.algorithm.algorithm_type.get_default_config().items():
509+
if getattr(self.algorithm, key, None) is None:
510+
setattr(self.algorithm, key, value)
511+
495512
policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn)
496513
if policy_fn_cls is None:
497514
raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}")

trinity/common/constants.py

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -62,64 +62,64 @@ class StorageType(CaseInsensitiveEnum):
6262
FILE = "file"
6363

6464

65-
class AlgorithmType(CaseInsensitiveEnum):
66-
"""Algorithm Type."""
67-
68-
SFT = "sft"
69-
PPO = "ppo"
70-
GRPO = "grpo"
71-
OPMD = "opmd"
72-
DPO = "dpo"
73-
74-
def is_rft(self) -> bool:
75-
"""Check if the algorithm is RFT."""
76-
return self in [
77-
AlgorithmType.PPO,
78-
AlgorithmType.GRPO,
79-
AlgorithmType.OPMD,
80-
]
81-
82-
def is_sft(self) -> bool:
83-
"""Check if the algorithm is SFT."""
84-
return self == AlgorithmType.SFT
85-
86-
def is_dpo(self) -> bool:
87-
"""Check if the algorithm is DPO."""
88-
return self == AlgorithmType.DPO
89-
90-
@property
91-
def use_critic(self) -> bool:
92-
"""Check if the algorithm uses critic."""
93-
return self == AlgorithmType.PPO
94-
95-
@property
96-
def use_reference(self) -> bool:
97-
"""Check if the algorithm uses reference."""
98-
return self in {
99-
AlgorithmType.PPO,
100-
AlgorithmType.GRPO,
101-
AlgorithmType.OPMD,
102-
AlgorithmType.DPO,
103-
}
104-
105-
@property
106-
def use_advantage(self) -> bool:
107-
"""Check if the algorithm uses advantage."""
108-
return self in {
109-
AlgorithmType.PPO,
110-
AlgorithmType.GRPO,
111-
AlgorithmType.OPMD,
112-
}
113-
114-
@property
115-
def can_balance_batch(self) -> bool:
116-
"""Check if the algorithm can balance batch."""
117-
return self in {
118-
AlgorithmType.SFT,
119-
AlgorithmType.PPO,
120-
AlgorithmType.GRPO,
121-
AlgorithmType.OPMD,
122-
}
65+
# class AlgorithmType(CaseInsensitiveEnum):
66+
# """Algorithm Type."""
67+
68+
# SFT = "sft"
69+
# PPO = "ppo"
70+
# GRPO = "grpo"
71+
# OPMD = "opmd"
72+
# DPO = "dpo"
73+
74+
# def is_rft(self) -> bool:
75+
# """Check if the algorithm is RFT."""
76+
# return self in [
77+
# AlgorithmType.PPO,
78+
# AlgorithmType.GRPO,
79+
# AlgorithmType.OPMD,
80+
# ]
81+
82+
# def is_sft(self) -> bool:
83+
# """Check if the algorithm is SFT."""
84+
# return self == AlgorithmType.SFT
85+
86+
# def is_dpo(self) -> bool:
87+
# """Check if the algorithm is DPO."""
88+
# return self == AlgorithmType.DPO
89+
90+
# @property
91+
# def use_critic(self) -> bool:
92+
# """Check if the algorithm uses critic."""
93+
# return self == AlgorithmType.PPO
94+
95+
# @property
96+
# def use_reference(self) -> bool:
97+
# """Check if the algorithm uses reference."""
98+
# return self in {
99+
# AlgorithmType.PPO,
100+
# AlgorithmType.GRPO,
101+
# AlgorithmType.OPMD,
102+
# AlgorithmType.DPO,
103+
# }
104+
105+
# @property
106+
# def use_advantage(self) -> bool:
107+
# """Check if the algorithm uses advantage."""
108+
# return self in {
109+
# AlgorithmType.PPO,
110+
# AlgorithmType.GRPO,
111+
# AlgorithmType.OPMD,
112+
# }
113+
114+
# @property
115+
# def can_balance_batch(self) -> bool:
116+
# """Check if the algorithm can balance batch."""
117+
# return self in {
118+
# AlgorithmType.SFT,
119+
# AlgorithmType.PPO,
120+
# AlgorithmType.GRPO,
121+
# AlgorithmType.OPMD,
122+
# }
123123

124124

125125
class MonitorType(CaseInsensitiveEnum):

0 commit comments

Comments
 (0)