
Commit 732d801

Add KL/Entropy Fn (#64)
1 parent 9d582e8 commit 732d801

File tree: 23 files changed, +361 / -75 lines

docs/sphinx_doc/source/tutorial/example_dpo.md

Lines changed: 3 additions & 2 deletions

@@ -48,6 +48,9 @@ name: <experiment_name>
 mode: train
 algorithm:
   algorithm_type: dpo
+  kl_loss_fn: k1
+  kl_loss_fn_args:
+    kl_coef: 0.1 # value of beta in DPO
 checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 model:
   model_path: /PATH/TO/MODEL/
@@ -70,8 +73,6 @@ buffer:
 trainer:
   trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
   save_interval: 30
-  actor_use_kl_loss: True
-  actor_kl_loss_coef: 0.1 # value of beta in DPO
 ```

 ### Run the Experiment
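The DPO example now routes its beta through the generic KL interface: `kl_loss_fn: k1` with `kl_loss_fn_args.kl_coef` replaces the old `actor_use_kl_loss` / `actor_kl_loss_coef` trainer keys. For reference, the sketch below shows where `kl_coef` (DPO's beta) conventionally enters the DPO objective; the function and variable names are illustrative, not Trinity's actual implementation.

```python
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, kl_coef=0.1):
    """Standard DPO objective; kl_coef plays the role of beta."""
    # Log-ratios of the trained policy against the frozen reference policy.
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    # beta scales how strongly the policy is kept close to the reference.
    logits = kl_coef * (chosen_logratio - rejected_logratio)
    return -F.logsigmoid(logits).mean()
```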

examples/dpo_humanlike/dpo.yaml

Lines changed: 3 additions & 2 deletions

@@ -3,6 +3,9 @@ name: "trinity_dpo"
 mode: train
 algorithm:
   algorithm_type: dpo
+  kl_loss_fn: k1
+  kl_loss_fn_args:
+    kl_coef: 0.1
 checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 model:
   model_path: /PATH/TO/MODEL
@@ -34,5 +37,3 @@ trainer:
   trainer_type: 'verl'
   trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
   save_interval: 30
-  actor_use_kl_loss: True
-  actor_kl_loss_coef: 0.1

tests/template/config.yaml

Lines changed: 3 additions & 1 deletion

@@ -8,10 +8,12 @@ algorithm:
   policy_loss_fn: ppo
   policy_loss_fn_args:
     clip_range: 0.2
-  advantage_fn_type: ppo_adv_fn
+  advantage_fn: ppo
   advantage_fn_args:
     gamma: 1.0
     lam: 1.0
+  kl_penalty_fn: k3
+  kl_loss_fn: k2

 model:
   model_path: ''
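The test template now exercises two KL settings at once: `kl_penalty_fn: k3` and `kl_loss_fn: k2`. The names `k1`/`k2`/`k3` conventionally refer to Schulman's per-token KL estimators; the sketch below shows what they usually compute. The commit's exact formulas and signatures are not visible here, so treat this as an assumption.

```python
import torch

def kl_estimate(logprob: torch.Tensor, ref_logprob: torch.Tensor, kind: str = "k1") -> torch.Tensor:
    """Per-token KL(policy || ref) estimators, following the usual k1/k2/k3 naming."""
    log_ratio = logprob - ref_logprob  # log pi(a|s) - log pi_ref(a|s)
    if kind == "k1":   # plain log-ratio: unbiased, high variance
        return log_ratio
    if kind == "k2":   # squared log-ratio: biased, low variance
        return 0.5 * log_ratio.pow(2)
    if kind == "k3":   # exp(-x) + x - 1: unbiased, low variance
        return torch.exp(-log_ratio) + log_ratio - 1
    raise ValueError(f"unknown KL estimator: {kind}")
```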

tests/trainer/trainer_test.py

Lines changed: 7 additions & 9 deletions

@@ -67,6 +67,10 @@ def test_trainer(self):
         actor_metrics = parser.metric_list("actor")
         self.assertTrue(len(actor_metrics) > 0)
         self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
+        actor_kl_metrics = parser.metric_list("actor/kl")
+        self.assertTrue(len(actor_kl_metrics) > 0)
+        critic_kl_metrics = parser.metric_list("critic/kl")
+        self.assertTrue(len(critic_kl_metrics) > 0)
         response_metrics = parser.metric_list("response_length")
         self.assertTrue(len(response_metrics) > 0)
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
@@ -86,7 +90,7 @@ def test_trainer(self):
         )
         self.assertTrue(os.path.exists(checkpoint_step_4))
         self.assertTrue(os.path.exists(checkpoint_step_8))
-
+        # TODO: Reinit will fail when using v1 engine, find a way to fix it
         ray.init(ignore_reinit_error=True)
         # test bench mode
         self.config.mode = "bench"
@@ -118,7 +122,7 @@ def test_trainer(self):
         self.config.algorithm.algorithm_type = AlgorithmType.GRPO
         self.config.algorithm.repeat_times = 4
         # self.config.algorithm.repeat_times = 8 # TODO: used for real testing
-        self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+        self.config.algorithm.advantage_fn = "grpo"
         self.config.algorithm.advantage_fn_args = {}
         # self.config.buffer.batch_size = 96 # TODO: used for real testing
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
@@ -143,8 +147,6 @@ def test_trainer(self):
         # self.assertTrue(0.4 < rewards[1] < 0.55)
         # self.assertTrue(0.6 < rewards[2] < 0.7)
         # self.assertTrue(0.6 < rewards[3] < 0.7)
-        ray.shutdown(_exiting_interpreter=True)
-        # check checkpoint

     def tearDown(self):
         # remove dir only when the test passed
@@ -157,7 +159,7 @@ def test_trainer(self):
         # test both mode
         self.config.algorithm.algorithm_type = AlgorithmType.GRPO
         self.config.algorithm.repeat_times = 4
-        self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
+        self.config.algorithm.advantage_fn = "grpo"
         self.config.algorithm.advantage_fn_args = {}
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
         self.config.buffer.trainer_input.sft_warmup_steps = 2
@@ -180,8 +182,6 @@ def test_trainer(self):
         response_metrics = parser.metric_list("response_length")
         self.assertTrue(len(response_metrics) > 0)
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
-        ray.shutdown(_exiting_interpreter=True)
-        # check checkpoint

     def tearDown(self):
         # remove dir only when the test passed
@@ -207,8 +207,6 @@ def test_trainer(self):
         actor_metrics = parser.metric_list("actor")
         self.assertTrue(len(actor_metrics) > 0)
         self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
-        ray.shutdown(_exiting_interpreter=True)
-        # check checkpoint

     def tearDown(self):
         # remove dir only when the test passed

trinity/algorithm/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -1,9 +1,15 @@
 from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
+from trinity.algorithm.kl_fn import KL_FN, KLFn
 from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn

 __all__ = [
     "AdvantageFn",
     "ADVANTAGE_FN",
     "PolicyLossFn",
     "POLICY_LOSS_FN",
+    "KLFn",
+    "KL_FN",
+    "EntropyLossFn",
+    "ENTROPY_LOSS_FN",
 ]
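With these exports, KL and entropy loss functions can be resolved from `trinity.algorithm` by the string names used in the configs above (`k1`, `k2`, `k3`). A hypothetical usage sketch follows, assuming the new registries use the same `register_module` pattern as `ADVANTAGE_FN` in the files below; the lookup method name and constructor signature are assumptions, since only the exports are visible in this commit.

```python
from trinity.algorithm import KL_FN, ENTROPY_LOSS_FN  # new in this commit

# Hypothetical: resolve the KL fn registered as "k1" and instantiate it with
# the kl_coef from kl_loss_fn_args in the yaml configs above.
kl_fn = KL_FN.get("k1")(kl_coef=0.1)
# ENTROPY_LOSS_FN is expected to follow the same name-based lookup pattern.
```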

trinity/algorithm/advantage_fn/grpo_advantage.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 from trinity.trainer.verl import core_algos


-@ADVANTAGE_FN.register_module("grpo_adv_fn")
+@ADVANTAGE_FN.register_module("grpo")
 class GRPOAdvantageFn(AdvantageFn):
     """GRPO advantage computation"""

trinity/algorithm/advantage_fn/opmd_advantage.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 from trinity.trainer.verl import core_algos


-@ADVANTAGE_FN.register_module("opmd_adv_fn")
+@ADVANTAGE_FN.register_module("opmd")
 class OPMDAdvantageFn(AdvantageFn):
     """OPMD advantage computation"""

trinity/algorithm/advantage_fn/ppo_advantage.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 from trinity.trainer.verl import core_algos


-@ADVANTAGE_FN.register_module("ppo_adv_fn")
+@ADVANTAGE_FN.register_module("ppo")
 class PPOAdvantageFn(AdvantageFn):
     def __init__(
         self,

trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 from trinity.trainer.verl import core_algos


-@ADVANTAGE_FN.register_module("reinforceplusplus_adv_fn")
+@ADVANTAGE_FN.register_module("reinforceplusplus")
 class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn):
     def __init__(self, gamma: float = 1.0) -> None:
         self.gamma = gamma

trinity/algorithm/advantage_fn/remax_advantage.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 from trinity.trainer.verl import core_algos


-@ADVANTAGE_FN.register_module("remax_adv_fn")
+@ADVANTAGE_FN.register_module("remax")
 class REMAXAdvantageFn(AdvantageFn):
     def __init__(self) -> None:
         pass
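Taken together, the five hunks above drop the `_adv_fn` suffix from the registered names (`ppo_adv_fn` -> `ppo`, `grpo_adv_fn` -> `grpo`, `opmd_adv_fn` -> `opmd`, `reinforceplusplus_adv_fn` -> `reinforceplusplus`, `remax_adv_fn` -> `remax`), and the config key changes from `advantage_fn_type` to `advantage_fn`. A sketch of an updated `algorithm` section, modeled on the test config above; the values are illustrative, and the `kl_penalty_fn`/`kl_loss_fn` split is assumed to distinguish a reward-side penalty from an actor-loss term.

```yaml
algorithm:
  algorithm_type: grpo
  advantage_fn: grpo          # was: advantage_fn_type: grpo_adv_fn
  advantage_fn_args: {}
  kl_loss_fn: k2              # replaces actor_use_kl_loss / actor_kl_loss_coef
  kl_loss_fn_args:
    kl_coef: 0.1
  kl_penalty_fn: k3
```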
