@@ -9659,6 +9659,10 @@ def test_ppo_composite_dists(self):
9659
9659
9660
9660
make_params = TensorDictModule(
9661
9661
lambda: (
9662
+ torch.ones(4),
9663
+ torch.ones(4),
9664
+ torch.ones(4),
9665
+ torch.ones(4),
9662
9666
torch.ones(4),
9663
9667
torch.ones(4),
9664
9668
torch.ones(4, 2),
@@ -9669,8 +9673,12 @@ def test_ppo_composite_dists(self):
9669
9673
),
9670
9674
in_keys=[],
9671
9675
out_keys=[
9672
- ("params", "gamma", "concentration"),
9673
- ("params", "gamma", "rate"),
9676
+ ("params", "gamma1", "concentration"),
9677
+ ("params", "gamma1", "rate"),
9678
+ ("params", "gamma2", "concentration"),
9679
+ ("params", "gamma2", "rate"),
9680
+ ("params", "gamma3", "concentration"),
9681
+ ("params", "gamma3", "rate"),
9674
9682
("params", "Kumaraswamy", "concentration0"),
9675
9683
("params", "Kumaraswamy", "concentration1"),
9676
9684
("params", "mixture", "logits"),
@@ -9687,24 +9695,30 @@ def mixture_constructor(logits, loc, scale):
9687
9695
dist_constructor = functools.partial(
9688
9696
CompositeDistribution,
9689
9697
distribution_map={
9690
- "gamma": d.Gamma,
9698
+ "gamma1": d.Gamma,
9699
+ "gamma2": d.Gamma,
9700
+ "gamma3": d.Gamma,
9691
9701
"Kumaraswamy": d.Kumaraswamy,
9692
9702
"mixture": mixture_constructor,
9693
9703
},
9694
9704
name_map={
9695
- "gamma": ("agent0", "action"),
9705
+ "gamma1": ("agent0", "action", "action1", "sub_action1"),
9706
+ "gamma2": ("agent0", "action", "action1", "sub_action2"),
9707
+ "gamma3": ("agent0", "action", "action2"),
9696
9708
"Kumaraswamy": ("agent1", "action"),
9697
- "mixture": ("agent2", "action" ),
9709
+ "mixture": ("agent2"),
9698
9710
},
9699
9711
)
9700
9712
policy = ProbSeq(
9701
9713
make_params,
9702
9714
ProbabilisticTensorDictModule(
9703
9715
in_keys=["params"],
9704
9716
out_keys=[
9705
- ("agent0", "action"),
9717
+ ("agent0", "action", "action1", "sub_action1"),
9718
+ ("agent0", "action", "action1", "sub_action2"),
9719
+ ("agent0", "action", "action2"),
9706
9720
("agent1", "action"),
9707
- ("agent2", "action" ),
9721
+ ("agent2"),
9708
9722
],
9709
9723
distribution_class=dist_constructor,
9710
9724
return_log_prob=True,
@@ -9739,14 +9753,18 @@ def mixture_constructor(logits, loc, scale):
9739
9753
ppo = cls(policy, value_operator, entropy_coeff=scalar_entropy)
9740
9754
ppo.set_keys(
9741
9755
action=[
9742
- ("agent0", "action"),
9756
+ ("agent0", "action", "action1", "sub_action1"),
9757
+ ("agent0", "action", "action1", "sub_action2"),
9758
+ ("agent0", "action", "action2"),
9743
9759
("agent1", "action"),
9744
- ("agent2", "action" ),
9760
+ ("agent2"),
9745
9761
],
9746
9762
sample_log_prob=[
9747
- ("agent0", "action_log_prob"),
9763
+ ("agent0", "action", "action1", "sub_action1_log_prob"),
9764
+ ("agent0", "action", "action1", "sub_action2_log_prob"),
9765
+ ("agent0", "action", "action2_log_prob"),
9748
9766
("agent1", "action_log_prob"),
9749
- ("agent2", "action_log_prob "),
9767
+ ("agent2_log_prob "),
9750
9768
],
9751
9769
)
9752
9770
loss = ppo(data)
@@ -9761,21 +9779,27 @@ def mixture_constructor(logits, loc, scale):
9761
9779
# keep per-head entropies instead of the aggregated tensor
9762
9780
set_composite_lp_aggregate(False).set()
9763
9781
coef_map = {
9764
- "agent0": 0.10,
9765
- "agent1": 0.05,
9766
- "agent2": 0.02,
9782
+ ("agent0", "action", "action1", "sub_action1_log_prob"): 0.02,
9783
+ ("agent0", "action", "action1", "sub_action2_log_prob"): 0.01,
9784
+ ("agent0", "action", "action2_log_prob"): 0.01,
9785
+ ("agent1", "action_log_prob"): 0.01,
9786
+ "agent2_log_prob": 0.01,
9767
9787
}
9768
9788
ppo_weighted = cls(policy, value_operator, entropy_coeff=coef_map)
9769
9789
ppo_weighted.set_keys(
9770
9790
action=[
9771
- ("agent0", "action"),
9791
+ ("agent0", "action", "action1", "sub_action1"),
9792
+ ("agent0", "action", "action1", "sub_action2"),
9793
+ ("agent0", "action", "action2"),
9772
9794
("agent1", "action"),
9773
- ("agent2", "action" ),
9795
+ ("agent2"),
9774
9796
],
9775
9797
sample_log_prob=[
9776
- ("agent0", "action_log_prob"),
9798
+ ("agent0", "action", "action1", "sub_action1_log_prob"),
9799
+ ("agent0", "action", "action1", "sub_action2_log_prob"),
9800
+ ("agent0", "action", "action2_log_prob"),
9777
9801
("agent1", "action_log_prob"),
9778
- ("agent2", "action_log_prob "),
9802
+ ("agent2_log_prob "),
9779
9803
],
9780
9804
)
9781
9805
loss = ppo_weighted(data)
@@ -9786,9 +9810,11 @@ def mixture_constructor(logits, loc, scale):
9786
9810
assert torch.isfinite(loss["loss_entropy"])
9787
9811
# Check individual loss is computed with the right weights
9788
9812
expected_loss = 0.0
9789
- for name, head_entropy in composite_entropy.items():
9813
+ for i, (_, head_entropy) in enumerate(
9814
+ composite_entropy.items(include_nested=True, leaves_only=True)
9815
+ ):
9790
9816
expected_loss -= (
9791
- coef_map[name] * _sum_td_features( head_entropy)
9817
+ coef_map[list(coef_map.keys())[i]] * head_entropy
9792
9818
).mean()
9793
9819
torch.testing.assert_close(
9794
9820
loss["loss_entropy"], expected_loss, rtol=1e-5, atol=1e-7
@@ -9846,7 +9872,7 @@ def test_weighted_entropy_scalar(self):
9846
9872
torch.testing.assert_close(out, torch.tensor(-1.0))
9847
9873
9848
9874
def test_weighted_entropy_mapping(self):
9849
- coef = {"head_0": 0.3, "head_1": 0.7}
9875
+ coef = {( "head_0", "action_log_prob") : 0.3, ( "head_1", "action_log_prob") : 0.7}
9850
9876
loss = self._make_entropy_loss(entropy_coeff=coef)
9851
9877
entropy = TensorDict(
9852
9878
{
@@ -9856,7 +9882,10 @@ def test_weighted_entropy_mapping(self):
9856
9882
[],
9857
9883
)
9858
9884
out = loss._weighted_loss_entropy(entropy)
9859
- expected = -(coef["head_0"] * 1.0 + coef["head_1"] * 2.0)
9885
+ expected = -(
9886
+ coef[("head_0", "action_log_prob")] * 1.0
9887
+ + coef[("head_1", "action_log_prob")] * 2.0
9888
+ )
9860
9889
torch.testing.assert_close(out, torch.tensor(expected))
9861
9890
9862
9891
def test_weighted_entropy_mapping_missing_key(self):
0 commit comments