Skip to content

Commit 48f596a

Browse files
shiweijiezeroweijie
andauthored
Fix EntropyLossFn (#77)
Co-authored-by: weijie <[email protected]>
1 parent fec7f3c commit 48f596a

File tree

2 files changed

+2
-10
lines changed

2 files changed

+2
-10
lines changed

docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ pip install -e .[dev]
299299
# pip install -e .\[dev\]
300300

301301
# Run code style checks
302-
pre-commit --all-files
302+
pre-commit run --all-files
303303

304304
# Commit the code after all checks pass
305305
git commit -am "create example workflow"

trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ def __call__(
3232
"""
3333

3434
@classmethod
35-
@abstractmethod
3635
def default_args(cls) -> Dict:
3736
"""
3837
Returns:
3938
`Dict`: The default arguments for the entropy loss function.
4039
"""
40+
return {"entropy_coef": 0.0}
4141

4242

4343
@ENTROPY_LOSS_FN.register_module("basic")
@@ -58,10 +58,6 @@ def __call__(
5858
entropy_loss = masked_mean(entropy, action_mask)
5959
return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()}
6060

61-
@classmethod
62-
def default_args(cls) -> Dict:
63-
return {"entropy_coef": 0.0}
64-
6561

6662
@ENTROPY_LOSS_FN.register_module("none")
6763
class DummyEntropyLossFn(EntropyLossFn):
@@ -79,7 +75,3 @@ def __call__(
7975
**kwargs,
8076
) -> Tuple[torch.Tensor, Dict]:
8177
return torch.tensor(0.0), {}
82-
83-
@classmethod
84-
def default_args(cls) -> Dict:
85-
return {}

0 commit comments

Comments
 (0)