9 changes: 9 additions & 0 deletions trinity/algorithm/__init__.py
@@ -0,0 +1,9 @@
from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn

__all__ = [
"AdvantageFn",
"ADVANTAGE_FN",
"PolicyLossFn",
"POLICY_LOSS_FN",
]
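
For orientation, a brief usage sketch of the names re-exported above. The registered names ("grpo", "ppo") and the `get` lookup method on `trinity.utils.registry.Registry` are assumptions for illustration, not confirmed API of this repository.

from trinity.algorithm import ADVANTAGE_FN, POLICY_LOSS_FN

# Hypothetical lookup by registered name; the `get` method and the
# "grpo" / "ppo" keys are assumed for illustration only.
advantage_fn_cls = ADVANTAGE_FN.get("grpo")
policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")

advantage_fn = advantage_fn_cls()
policy_loss_fn = policy_loss_fn_cls()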
Empty file.
21 changes: 21 additions & 0 deletions trinity/algorithm/advantage_fn/advantage_fn.py
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

from trinity.utils.registry import Registry

ADVANTAGE_FN = Registry("advantage_fn")


class AdvantageFn(ABC):
@abstractmethod
def __call__(self, exps: Any, **kwargs: Dict) -> Tuple[Any, Dict]:
"""Calculate advantages from experiences

Args:
exps (`DataProto`): The input experiences.
kwargs (`Dict`): The step-level parameters for calculating advantages.

Returns:
`Any`: The experiences with advantages.
`Dict`: The metrics for logging.
"""
Empty file.
Empty file.
Empty file.
38 changes: 38 additions & 0 deletions trinity/algorithm/policy_loss_fn/policy_loss_fn.py
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

import torch

from trinity.utils.registry import Registry

POLICY_LOSS_FN = Registry("policy_loss_fn")


class PolicyLossFn(ABC):
"""
Policy Loss Function
"""

@abstractmethod
def __call__(
self,
logprob: torch.Tensor,
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
experiences: Any,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
"""
Args:
            logprob (`torch.Tensor`): The log probability generated by the current policy model.
            old_logprob (`torch.Tensor`): The log probability generated by the old (rollout) policy model.
action_mask (`torch.Tensor`): The action mask.
advantages (`torch.Tensor`): The advantages.
experiences (`DataProto`): The input experiences.
kwargs (`Dict`): The step-level parameters for calculating the policy loss.

Returns:
`torch.Tensor`: Policy loss
`Dict`: The metrics for logging.
"""