
Commit 4fc9e58

Juan de los Rios authored and vmoens committed
update comment, format
1 parent 5b8d8aa commit 4fc9e58

File tree

2 files changed: 11 additions & 7 deletions


test/test_cost.py

Lines changed: 10 additions & 7 deletions
@@ -9779,11 +9779,11 @@ def mixture_constructor(logits, loc, scale):
         # keep per-head entropies instead of the aggregated tensor
         set_composite_lp_aggregate(False).set()
         coef_map = {
-            ("agent0", "action", "action1", "sub_action1_log_prob"):0.02,
-            ("agent0", "action", "action1", "sub_action2_log_prob"):0.01,
-            ("agent0", "action", "action2_log_prob"):0.01,
-            ("agent1", "action_log_prob"):0.01,
-            "agent2_log_prob":0.01,
+            ("agent0", "action", "action1", "sub_action1_log_prob"): 0.02,
+            ("agent0", "action", "action1", "sub_action2_log_prob"): 0.01,
+            ("agent0", "action", "action2_log_prob"): 0.01,
+            ("agent1", "action_log_prob"): 0.01,
+            "agent2_log_prob": 0.01,
         }
         ppo_weighted = cls(policy, value_operator, entropy_coeff=coef_map)
         ppo_weighted.set_keys(
@@ -9872,7 +9872,7 @@ def test_weighted_entropy_scalar(self):
         torch.testing.assert_close(out, torch.tensor(-1.0))

     def test_weighted_entropy_mapping(self):
-        coef = {("head_0","action_log_prob"): 0.3, ("head_1","action_log_prob"): 0.7}
+        coef = {("head_0", "action_log_prob"): 0.3, ("head_1", "action_log_prob"): 0.7}
         loss = self._make_entropy_loss(entropy_coeff=coef)
         entropy = TensorDict(
             {
@@ -9882,7 +9882,10 @@ def test_weighted_entropy_mapping(self):
             [],
         )
         out = loss._weighted_loss_entropy(entropy)
-        expected = -(coef[("head_0","action_log_prob")] * 1.0 + coef[("head_1","action_log_prob")] * 2.0)
+        expected = -(
+            coef[("head_0", "action_log_prob")] * 1.0
+            + coef[("head_1", "action_log_prob")] * 2.0
+        )
         torch.testing.assert_close(out, torch.tensor(expected))

     def test_weighted_entropy_mapping_missing_key(self):
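
For context, the expectation encoded by the reformatted test can be reproduced by hand: with composite log-prob aggregation disabled, the entropy arrives as one leaf per head, each leaf is scaled by the coefficient registered under its full nested key, and the negated sum is the entropy term of the loss. A minimal sketch of that arithmetic, independent of the test harness (the values 1.0 and 2.0 and the key names come from the diff; the TensorDict construction here is only illustrative):

import torch
from tensordict import TensorDict

# Per-head coefficients, keyed by the full nested key of each entropy head.
coef = {("head_0", "action_log_prob"): 0.3, ("head_1", "action_log_prob"): 0.7}

# Per-head entropies, one leaf per head (what the loss sees when
# set_composite_lp_aggregate(False) is in effect).
entropy = TensorDict({}, batch_size=[])
entropy["head_0", "action_log_prob"] = torch.tensor(1.0)
entropy["head_1", "action_log_prob"] = torch.tensor(2.0)

# Weighted entropy bonus: negated, coefficient-weighted sum over the heads.
expected = -(
    coef[("head_0", "action_log_prob")] * entropy["head_0", "action_log_prob"]
    + coef[("head_1", "action_log_prob")] * entropy["head_1", "action_log_prob"]
)
print(expected)  # tensor(-1.7000)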

torchrl/objectives/ppo.py

Lines changed: 1 addition & 0 deletions
@@ -912,6 +912,7 @@ def _weighted_loss_entropy(

         If `self._entropy_coeff_map` is provided, apply per-head entropy coefficients.
         Otherwise, use the scalar `self.entropy_coeff`.
+        The entries in self._entropy_coeff_map require the full nested key to the entropy head.
         """
         if self._entropy_coeff_map is None:
             if is_tensor_collection(entropy):
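
The added docstring line is the substantive part of the commit: when `entropy_coeff` is given as a mapping, each entry must be keyed by the complete nested path down to the per-head leaf (in these tests the leaf names end in `_log_prob`), not by a head prefix. A small illustrative sketch of that key convention, reusing the coef_map from the test above (the check below is only reader-side illustration, not something the library performs):

# Full nested keys, exactly as in the test's coef_map: the complete path from
# the root of the entropy structure down to the per-head leaf.
coef_map = {
    ("agent0", "action", "action1", "sub_action1_log_prob"): 0.02,
    ("agent0", "action", "action1", "sub_action2_log_prob"): 0.01,
    ("agent0", "action", "action2_log_prob"): 0.01,
    ("agent1", "action_log_prob"): 0.01,
    "agent2_log_prob": 0.01,
}

# Illustrative sanity check: every key should point at a leaf,
# not at an intermediate group such as ("agent0", "action").
for key in coef_map:
    leaf = key[-1] if isinstance(key, tuple) else key
    assert leaf.endswith("_log_prob"), f"{key!r} is not a full path to an entropy head"

# The mapping is then passed to the loss exactly as the test does:
#     ppo_weighted = cls(policy, value_operator, entropy_coeff=coef_map)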

0 commit comments
