Skip to content

Commit d7c43fe

Browse files
authored
Init Algorithm Module (#58)
1 parent 318da40 commit d7c43fe

File tree

7 files changed

+68
-0
lines changed

7 files changed

+68
-0
lines changed

trinity/algorithm/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Public API of the trinity.algorithm package: re-export the advantage and
# policy-loss abstractions together with their component registries.
from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn

__all__ = [
    "AdvantageFn",
    "ADVANTAGE_FN",
    "PolicyLossFn",
    "POLICY_LOSS_FN",
]

trinity/algorithm/advantage_fn/__init__.py

Whitespace-only changes.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

from trinity.utils.registry import Registry

# Registry for advantage-function implementations, created under the
# "advantage_fn" namespace. Implementations of `AdvantageFn` are looked up
# here by name.
ADVANTAGE_FN = Registry("advantage_fn")
7+
8+
9+
class AdvantageFn(ABC):
    """Abstract interface for advantage-calculation strategies.

    A concrete implementation takes a batch of experiences, computes the
    advantages for it, and returns the augmented experiences together with
    a dictionary of metrics for logging.
    """

    @abstractmethod
    def __call__(self, exps: Any, **kwargs: Any) -> Tuple[Any, Dict]:
        """Calculate advantages from experiences.

        Args:
            exps (`DataProto`): The input experiences.
            kwargs (`Dict`): The step-level parameters for calculating advantages.

        Returns:
            `Any`: The experiences with advantages.
            `Dict`: The metrics for logging.
        """

trinity/algorithm/entropy_loss/__init__.py

Whitespace-only changes.

trinity/algorithm/kl_loss/__init__.py

Whitespace-only changes.

trinity/algorithm/policy_loss_fn/__init__.py

Whitespace-only changes.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

import torch

from trinity.utils.registry import Registry

# Registry for policy-loss-function implementations, created under the
# "policy_loss_fn" namespace. Implementations of `PolicyLossFn` are looked
# up here by name.
POLICY_LOSS_FN = Registry("policy_loss_fn")
9+
10+
11+
class PolicyLossFn(ABC):
    """Abstract interface for policy loss functions.

    A concrete implementation computes a scalar policy loss from the model's
    log probabilities and the precomputed advantages, and returns it together
    with a dictionary of metrics for logging.
    """

    @abstractmethod
    def __call__(
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        experiences: Any,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the policy loss for a batch of experiences.

        Args:
            logprob (`torch.Tensor`): The log probability generated by the policy model.
            old_logprob (`torch.Tensor`): The log probability generated by the reference model.
                NOTE(review): the name `old_logprob` suggests the *previous/old policy*
                rather than a reference model — confirm against implementations.
            action_mask (`torch.Tensor`): The action mask.
            advantages (`torch.Tensor`): The advantages.
            experiences (`DataProto`): The input experiences.
            kwargs (`Dict`): The step-level parameters for calculating the policy loss.

        Returns:
            `torch.Tensor`: Policy loss
            `Dict`: The metrics for logging.
        """

0 commit comments

Comments
 (0)