diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
new file mode 100644
index 0000000000..f65ec67b47
--- /dev/null
+++ b/trinity/algorithm/__init__.py
@@ -0,0 +1,9 @@
+from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+
+__all__ = [
+    "AdvantageFn",
+    "ADVANTAGE_FN",
+    "PolicyLossFn",
+    "POLICY_LOSS_FN",
+]
diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/algorithm/advantage_fn/advantage_fn.py b/trinity/algorithm/advantage_fn/advantage_fn.py
new file mode 100644
index 0000000000..7e965b017c
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/advantage_fn.py
@@ -0,0 +1,21 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Tuple
+
+from trinity.utils.registry import Registry
+
+ADVANTAGE_FN = Registry("advantage_fn")
+
+
+class AdvantageFn(ABC):
+    @abstractmethod
+    def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]:
+        """Calculate advantages from experiences.
+
+        Args:
+            exps (`DataProto`): The input experiences.
+            kwargs (`Dict`): The step-level parameters for calculating advantages.
+
+        Returns:
+            `Any`: The experiences with advantages.
+            `Dict`: The metrics for logging.
+        """
diff --git a/trinity/algorithm/entropy_loss/__init__.py b/trinity/algorithm/entropy_loss/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/algorithm/kl_loss/__init__.py b/trinity/algorithm/kl_loss/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
new file mode 100644
index 0000000000..392f80e521
--- /dev/null
+++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -0,0 +1,38 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Tuple
+
+import torch
+
+from trinity.utils.registry import Registry
+
+POLICY_LOSS_FN = Registry("policy_loss_fn")
+
+
+class PolicyLossFn(ABC):
+    """
+    Policy Loss Function
+    """
+
+    @abstractmethod
+    def __call__(
+        self,
+        logprob: torch.Tensor,
+        old_logprob: torch.Tensor,
+        action_mask: torch.Tensor,
+        advantages: torch.Tensor,
+        experiences: Any,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        """
+        Args:
+            logprob (`torch.Tensor`): The log probability generated by the current policy model.
+            old_logprob (`torch.Tensor`): The log probability generated by the old policy model (the policy that produced the experiences).
+            action_mask (`torch.Tensor`): The action mask.
+            advantages (`torch.Tensor`): The advantages.
+            experiences (`DataProto`): The input experiences.
+            kwargs (`Dict`): The step-level parameters for calculating the policy loss.
+
+        Returns:
+            `torch.Tensor`: The policy loss.
+            `Dict`: The metrics for logging.
+        """
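
To illustrate how a concrete implementation would plug into the `ADVANTAGE_FN` registry, here is a minimal sketch, not part of this diff. It assumes the `Registry` in trinity/utils/registry.py exposes a Data-Juicer-style `register_module` decorator and that experiences arrive as a `DataProto`-like object with a `batch` mapping containing a `"rewards"` tensor; the class name `MeanBaselineAdvantageFn` and the key `"mean_baseline"` are hypothetical.

from typing import Any, Dict, Tuple

from trinity.algorithm import ADVANTAGE_FN, AdvantageFn


@ADVANTAGE_FN.register_module("mean_baseline")  # decorator name is an assumption
class MeanBaselineAdvantageFn(AdvantageFn):
    """Hypothetical advantage estimator: reward minus the batch-mean baseline."""

    def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]:
        rewards = exps.batch["rewards"]  # assumed DataProto-style field layout
        advantages = rewards - rewards.mean()
        exps.batch["advantages"] = advantages
        metrics = {"advantage/mean": advantages.mean().item()}
        return exps, metrics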
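
A `PolicyLossFn` subclass would register itself under `POLICY_LOSS_FN` in the same way and implement the abstract `__call__`. The sketch below uses the standard PPO clipped surrogate; the key `"ppo_clip"`, the `clip_range` constructor parameter, and the `register_module` decorator are again assumptions rather than part of this diff.

from typing import Any, Dict, Tuple

import torch

from trinity.algorithm import POLICY_LOSS_FN, PolicyLossFn


@POLICY_LOSS_FN.register_module("ppo_clip")  # decorator name is an assumption
class PPOClipPolicyLossFn(PolicyLossFn):
    """Sketch of the standard PPO clipped surrogate loss."""

    def __init__(self, clip_range: float = 0.2):
        self.clip_range = clip_range

    def __call__(
        self,
        logprob: torch.Tensor,
        old_logprob: torch.Tensor,
        action_mask: torch.Tensor,
        advantages: torch.Tensor,
        experiences: Any,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        # Importance ratio between the current and old policies.
        ratio = torch.exp(logprob - old_logprob)
        clipped = torch.clamp(ratio, 1.0 - self.clip_range, 1.0 + self.clip_range)
        # Pessimistic (clipped) surrogate objective, negated for minimization.
        per_token_loss = -torch.min(ratio * advantages, clipped * advantages)
        loss = (per_token_loss * action_mask).sum() / action_mask.sum()
        return loss, {"actor/ppo_loss": loss.item()}

Keeping both the advantage estimator and the policy loss behind string-keyed registries lets a trainer assemble an algorithm from configuration names alone; the exact lookup call depends on how `trinity.utils.registry.Registry` is implemented.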