@@ -9779,11 +9779,11 @@ def mixture_constructor(logits, loc, scale):
9779
9779
# keep per-head entropies instead of the aggregated tensor
9780
9780
set_composite_lp_aggregate(False).set()
9781
9781
coef_map = {
9782
- ("agent0", "action", "action1", "sub_action1_log_prob"):0.02,
9783
- ("agent0", "action", "action1", "sub_action2_log_prob"):0.01,
9784
- ("agent0", "action", "action2_log_prob"):0.01,
9785
- ("agent1", "action_log_prob"):0.01,
9786
- "agent2_log_prob":0.01,
9782
+ ("agent0", "action", "action1", "sub_action1_log_prob"): 0.02,
9783
+ ("agent0", "action", "action1", "sub_action2_log_prob"): 0.01,
9784
+ ("agent0", "action", "action2_log_prob"): 0.01,
9785
+ ("agent1", "action_log_prob"): 0.01,
9786
+ "agent2_log_prob": 0.01,
9787
9787
}
9788
9788
ppo_weighted = cls(policy, value_operator, entropy_coeff=coef_map)
9789
9789
ppo_weighted.set_keys(
@@ -9872,7 +9872,7 @@ def test_weighted_entropy_scalar(self):
9872
9872
torch.testing.assert_close(out, torch.tensor(-1.0))
9873
9873
9874
9874
def test_weighted_entropy_mapping(self):
9875
- coef = {("head_0","action_log_prob"): 0.3, ("head_1","action_log_prob"): 0.7}
9875
+ coef = {("head_0", "action_log_prob"): 0.3, ("head_1", "action_log_prob"): 0.7}
9876
9876
loss = self._make_entropy_loss(entropy_coeff=coef)
9877
9877
entropy = TensorDict(
9878
9878
{
@@ -9882,7 +9882,10 @@ def test_weighted_entropy_mapping(self):
9882
9882
[],
9883
9883
)
9884
9884
out = loss._weighted_loss_entropy(entropy)
9885
- expected = -(coef[("head_0","action_log_prob")] * 1.0 + coef[("head_1","action_log_prob")] * 2.0)
9885
+ expected = -(
9886
+ coef[("head_0", "action_log_prob")] * 1.0
9887
+ + coef[("head_1", "action_log_prob")] * 2.0
9888
+ )
9886
9889
torch.testing.assert_close(out, torch.tensor(expected))
9887
9890
9888
9891
def test_weighted_entropy_mapping_missing_key(self):
0 commit comments