
Commit fe217aa

Refactor advantage computation, and delete RayPPOTrainer.fit (#61)
1 parent 5cd6cb6 commit fe217aa

15 files changed: +342 −387 lines changed

tests/template/config.yaml

Lines changed: 5 additions & 0 deletions
@@ -8,6 +8,11 @@ algorithm:
   policy_loss_fn: ppo
   policy_loss_fn_args:
     clip_range: 0.2
+  advantage_fn_type: ppo_adv_fn
+  advantage_fn_args:
+    gamma: 1.0
+    lam: 1.0
+
 model:
   model_path: ''
   max_prompt_tokens: 2048
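
For reference, the two new keys mirror the refactored advantage-function classes introduced later in this commit: advantage_fn_type is the registry name of PPOAdvantageFn ("ppo_adv_fn"), and advantage_fn_args matches its constructor defaults, so it is presumably forwarded to the constructor. A minimal, hedged sketch of that mapping (direct instantiation only; the trainer-side registry lookup is not shown in this diff):

# Hedged sketch, not part of this commit: turning the config values above into
# an advantage function instance, using only names that appear in this diff.
from trinity.algorithm.advantage_fn import PPOAdvantageFn

advantage_fn_args = {"gamma": 1.0, "lam": 1.0}  # from algorithm.advantage_fn_args

# Start from the class defaults, then apply the config overrides.
init_args = {**PPOAdvantageFn.default_args(), **advantage_fn_args}
advantage_fn = PPOAdvantageFn(**init_args)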

trinity/algorithm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
 from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
 
 __all__ = [

trinity/algorithm/advantage_fn/__init__.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn
+from trinity.algorithm.advantage_fn.opmd_advantage import OPMDAdvantageFn
+from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn
+from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import (
+    REINFORCEPLUSPLUSAdvantageFn,
+)
+from trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn
+from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn
+
+__all__ = [
+    "ADVANTAGE_FN",
+    "AdvantageFn",
+    "PPOAdvantageFn",
+    "GRPOAdvantageFn",
+    "REINFORCEPLUSPLUSAdvantageFn",
+    "REMAXAdvantageFn",
+    "RLOOAdvantageFn",
+    "OPMDAdvantageFn",
+]

trinity/algorithm/advantage_fn/advantage_fn.py

Lines changed: 9 additions & 1 deletion
@@ -16,6 +16,14 @@ def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]:
             kwargs (`Dict`): The step-level parameters for calculating advantages.
 
         Returns:
-            `Any`: The experiences with advantages.
+            `DataProto`: The experiences with advantages.
             `Dict`: The metrics for logging.
         """
+
+    @classmethod
+    @abstractmethod
+    def default_args(cls) -> Dict:
+        """
+        Returns:
+            `Dict`: The default init arguments for the advantage function.
+        """

trinity/algorithm/advantage_fn/grpo_advantage.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+"""GRPO advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("grpo_adv_fn")
+class GRPOAdvantageFn(AdvantageFn):
+    """GRPO advantage computation"""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_grpo_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            eos_mask=exps.batch["response_mask"],
+            index=exps.non_tensor_batch["uid"],
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {}
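
core_algos.compute_grpo_outcome_advantage itself is not part of this diff. For intuition only, a simplified, hedged sketch of the group-relative normalization GRPO is known for; it is not the core_algos implementation, and it ignores broadcasting over response tokens and masking. The example values and the 1e-6 epsilon are made up; the grouping corresponds to the "uid" index passed above.

# Illustrative only: z-score each outcome reward within its prompt group.
import torch

scores = torch.tensor([1.0, 0.0, 1.0, 0.5])   # one scalar outcome per response
uids = ["p0", "p0", "p1", "p1"]                # responses grouped by prompt

advantages = torch.zeros_like(scores)
for uid in set(uids):
    idx = [i for i, u in enumerate(uids) if u == uid]
    group = scores[idx]
    advantages[idx] = (group - group.mean()) / (group.std() + 1e-6)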

trinity/algorithm/advantage_fn/opmd_advantage.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+"""OPMD advantage computation
+
+Adapted from compute_advantage_opmd in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("opmd_adv_fn")
+class OPMDAdvantageFn(AdvantageFn):
+    """OPMD advantage computation"""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_opmd_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            eos_mask=exps.batch["response_mask"],
+            # TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
+            index=exps.non_tensor_batch["uid"],
+            opmd_baseline="mean",
+            tau=1.0,
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {}

trinity/algorithm/advantage_fn/ppo_advantage.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+"""PPO's GAE advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("ppo_adv_fn")
+class PPOAdvantageFn(AdvantageFn):
+    def __init__(
+        self,
+        gamma: float = 1.0,
+        lam: float = 1.0,
+    ) -> None:
+        self.gamma = gamma
+        self.lam = lam
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_gae_advantage_return(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            values=exps.batch["values"],
+            eos_mask=exps.batch["response_mask"],
+            gamma=self.gamma,
+            lam=self.lam,
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "gamma": 1.0,
+            "lam": 1.0,
+        }
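
gamma and lam are the standard GAE parameters (discount factor and the bias-variance knob of the lambda-return), matching the new config keys above. compute_gae_advantage_return itself is not shown in this diff; for intuition, a simplified reference under the usual recursion A_t = delta_t + gamma*lam*A_{t+1} with delta_t = r_t + gamma*V_{t+1} - V_t, ignoring the masking and whitening details of the real implementation:

# Simplified GAE reference (illustrative only, not core_algos): assumes the
# value after the final token is zero and applies the response mask at the end.
import torch

def gae_reference(rewards, values, mask, gamma=1.0, lam=1.0):
    T = rewards.shape[-1]
    advantages = torch.zeros_like(rewards)
    next_adv, next_value = 0.0, 0.0
    for t in reversed(range(T)):
        delta = rewards[..., t] + gamma * next_value - values[..., t]
        next_adv = delta + gamma * lam * next_adv
        advantages[..., t] = next_adv
        next_value = values[..., t]
    returns = advantages + values
    return advantages * mask, returns * mask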

trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+"""REINFORCE++ advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("reinforceplusplus_adv_fn")
+class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn):
+    def __init__(self, gamma: float = 1.0) -> None:
+        self.gamma = gamma
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            eos_mask=exps.batch["response_mask"],
+            gamma=self.gamma,
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "gamma": 1.0,
+        }

trinity/algorithm/advantage_fn/remax_advantage.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+"""REMAX advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("remax_adv_fn")
+class REMAXAdvantageFn(AdvantageFn):
+    def __init__(self) -> None:
+        pass
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_remax_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            reward_baselines=exps.batch["reward_baselines"],
+            eos_mask=exps.batch["response_mask"],
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {}

trinity/algorithm/advantage_fn/rloo_advantage.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+"""RLOO advantage computation
+
+Adapted from compute_advantage_ppo in original ray_trainer.py
+"""
+
+from typing import Dict, Tuple
+
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.trainer.verl import core_algos
+
+
+@ADVANTAGE_FN.register_module("rloo_adv_fn")
+class RLOOAdvantageFn(AdvantageFn):
+    def __init__(self) -> None:
+        pass
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        advantages, returns = core_algos.compute_rloo_outcome_advantage(
+            token_level_rewards=exps.batch["token_level_rewards"],
+            eos_mask=exps.batch["response_mask"],
+            index=exps.non_tensor_batch["uid"],
+        )
+        exps.batch["advantages"] = advantages
+        exps.batch["returns"] = returns
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {}
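
For intuition only, a hedged sketch of the leave-one-out baseline RLOO is named after; it is not the core_algos implementation, ignores token-level broadcasting, and uses made-up rewards for a single prompt group (groups again correspond to the "uid" index passed above).

# Illustrative leave-one-out baseline for one prompt group: each response's
# advantage is its reward minus the mean reward of the *other* responses.
import torch

rewards = torch.tensor([1.0, 0.0, 0.5, 0.5])   # one scalar reward per response
n = rewards.numel()
loo_baseline = (rewards.sum() - rewards) / (n - 1)  # mean over the others
advantages = rewards - loo_baseline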
