From 7e8c9443fbdd962cf4cb087940b603f6886bdf35 Mon Sep 17 00:00:00 2001 From: weijie Date: Wed, 11 Jun 2025 16:25:57 +0800 Subject: [PATCH 1/3] 1. Initialize BasicEntropyLossFn default_args entropy_coef to zero; otherwise, check_and_update will raise an error. 2. Fix typo in programming_guide --- docs/sphinx_doc/source/tutorial/trinity_programming_guide.md | 2 +- trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index 2e4daeab0b..7119b1af35 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -299,7 +299,7 @@ pip install -e .[dev] # pip install -e .\[dev\] # Run code style checks -pre-commit --all-files +pre-commit run --all-files # Commit the code after all checks pass git commit -am "create example workflow" diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py index cf102dd6b7..b43847da7c 100644 --- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -38,6 +38,7 @@ def default_args(cls) -> Dict: Returns: `Dict`: The default arguments for the entropy loss function. """ + return {"entropy_coef": 0.0} @ENTROPY_LOSS_FN.register_module("basic") From b5d3e096dd8a0fbeefbd9e574f00823259579699 Mon Sep 17 00:00:00 2001 From: weijie Date: Wed, 11 Jun 2025 16:43:33 +0800 Subject: [PATCH 2/3] 1. Initialize BasicEntropyLossFn default_args entropy_coef to zero; otherwise, check_and_update will raise an error. 2. 
fix writing typo in programming_guide --- .../entropy_loss_fn/entropy_loss_fn.py | 33 +++++++------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py index b43847da7c..ad5db23ea6 100644 --- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -16,10 +16,10 @@ class EntropyLossFn(ABC): @abstractmethod def __call__( - self, - entropy: torch.Tensor, - action_mask: torch.Tensor, - **kwargs, + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, ) -> Tuple[torch.Tensor, Dict]: """ Args: @@ -32,7 +32,6 @@ def __call__( """ @classmethod - @abstractmethod def default_args(cls) -> Dict: """ Returns: @@ -51,18 +50,14 @@ def __init__(self, entropy_coef: float): self.entropy_coef = entropy_coef def __call__( - self, - entropy: torch.Tensor, - action_mask: torch.Tensor, - **kwargs, + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, ) -> Tuple[torch.Tensor, Dict]: entropy_loss = masked_mean(entropy, action_mask) return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()} - @classmethod - def default_args(cls) -> Dict: - return {"entropy_coef": 0.0} - @ENTROPY_LOSS_FN.register_module("none") class DummyEntropyLossFn(EntropyLossFn): @@ -74,13 +69,9 @@ def __init__(self): pass def __call__( - self, - entropy: torch.Tensor, - action_mask: torch.Tensor, - **kwargs, + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, ) -> Tuple[torch.Tensor, Dict]: return torch.tensor(0.0), {} - - @classmethod - def default_args(cls) -> Dict: - return {} From cece60610ee0c308244d04b97e538196b6816767 Mon Sep 17 00:00:00 2001 From: weijie Date: Wed, 11 Jun 2025 16:52:27 +0800 Subject: [PATCH 3/3] 1. 
Reformat --- .../entropy_loss_fn/entropy_loss_fn.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py index ad5db23ea6..e575caa449 100644 --- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -16,10 +16,10 @@ class EntropyLossFn(ABC): @abstractmethod def __call__( - self, - entropy: torch.Tensor, - action_mask: torch.Tensor, - **kwargs, + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, ) -> Tuple[torch.Tensor, Dict]: """ Args: @@ -50,10 +50,10 @@ def __init__(self, entropy_coef: float): self.entropy_coef = entropy_coef def __call__( - self, - entropy: torch.Tensor, - action_mask: torch.Tensor, - **kwargs, + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, ) -> Tuple[torch.Tensor, Dict]: entropy_loss = masked_mean(entropy, action_mask) return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()} @@ -69,9 +69,9 @@ def __init__(self): pass def __call__( - self, - entropy: torch.Tensor, - action_mask: torch.Tensor, - **kwargs, + self, + entropy: torch.Tensor, + action_mask: torch.Tensor, + **kwargs, ) -> Tuple[torch.Tensor, Dict]: return torch.tensor(0.0), {}