@@ -9779,11 +9779,11 @@ def mixture_constructor(logits, loc, scale):
        # keep per-head entropies instead of the aggregated tensor
        set_composite_lp_aggregate(False).set()
        coef_map = {
-           ("agent0", "action", "action1", "sub_action1_log_prob"): 0.10,
-           "sub_action2_log_prob": 0.10,
-           "action2": 0.10,
-           ("agent1", "action_log_prob"): 0.10,
-           "agent2_log_prob": 0.02,
+           ("agent0", "action", "action1", "sub_action1_log_prob"): 0.02,
+           ("agent0", "action", "action1", "sub_action2_log_prob"): 0.01,
+           ("agent0", "action", "action2_log_prob"): 0.01,
+           ("agent1", "action_log_prob"): 0.01,
+           "agent2_log_prob": 0.01,
        }
        ppo_weighted = cls(policy, value_operator, entropy_coeff=coef_map)
        ppo_weighted.set_keys(
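For context, a minimal stand-alone sketch of how a mixed coefficient map like `coef_map` above can be folded into a single entropy bonus. This illustrates the technique only; it is not TorchRL's actual `_weighted_loss_entropy`, and the helper name `weighted_entropy_bonus` plus the `.mean()` reduction are assumptions. Each key, whether a plain string or a fully qualified nested tuple, selects one per-head entropy tensor and weights it by its coefficient:

import torch
from tensordict import TensorDict

def weighted_entropy_bonus(entropy: TensorDict, coef: dict) -> torch.Tensor:
    # Hypothetical helper, not TorchRL API. Keys may be plain strings
    # ("agent2_log_prob") or nested tuples such as
    # ("agent0", "action", "action1", "sub_action1_log_prob"), matching the
    # per-head layout kept when composite log-probs are not aggregated.
    total = torch.zeros(())
    for key, coeff in coef.items():
        total = total + coeff * entropy.get(key).mean()
    return -total  # the entropy bonus enters the loss with a negative sign

Under this sketch, {("head_0", "action_log_prob"): 0.3, ("head_1", "action_log_prob"): 0.7} applied to per-head entropies 1.0 and 2.0 yields -(0.3 * 1.0 + 0.7 * 2.0) = -1.7, which is what test_weighted_entropy_mapping below asserts.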
@@ -9872,7 +9872,7 @@ def test_weighted_entropy_scalar(self):
        torch.testing.assert_close(out, torch.tensor(-1.0))

    def test_weighted_entropy_mapping(self):
-       coef = {"head_0": 0.3, "head_1": 0.7}
+       coef = {("head_0", "action_log_prob"): 0.3, ("head_1", "action_log_prob"): 0.7}
        loss = self._make_entropy_loss(entropy_coeff=coef)
        entropy = TensorDict(
            {
@@ -9882,7 +9882,7 @@ def test_weighted_entropy_mapping(self):
            [],
        )
        out = loss._weighted_loss_entropy(entropy)
-       expected = -(coef["head_0"] * 1.0 + coef["head_1"] * 2.0)
+       expected = -(coef[("head_0", "action_log_prob")] * 1.0 + coef[("head_1", "action_log_prob")] * 2.0)
        torch.testing.assert_close(out, torch.tensor(expected))

    def test_weighted_entropy_mapping_missing_key(self):
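test_weighted_entropy_mapping_missing_key suggests that a head whose entropy has no matching coefficient should not be silently dropped. A plausible strict variant of the sketch above; the keys(include_nested=True, leaves_only=True) enumeration is real tensordict API, but raising KeyError is an assumption about the intended behavior, not TorchRL's documented one:

def weighted_entropy_bonus_strict(entropy: TensorDict, coef: dict) -> torch.Tensor:
    # Assumed behavior: fail loudly when a per-head entropy has no
    # matching coefficient instead of silently omitting it from the bonus.
    missing = [
        key
        for key in entropy.keys(include_nested=True, leaves_only=True)
        if key not in coef
    ]
    if missing:
        raise KeyError(f"no entropy coefficient for heads: {missing}")
    return weighted_entropy_bonus(entropy, coef)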