
Commit d34dbb2

[Feature,BugFix] Fix composite entropy for nested keys (#3101)

Authored by: juandelos (Juan de los Rios)
Co-authored-by: Juan de los Rios <[email protected]>
1 parent 978424e commit d34dbb2

File tree: 2 files changed (+64, -33 lines)


test/test_cost.py

Lines changed: 51 additions & 22 deletions
@@ -9659,6 +9659,10 @@ def test_ppo_composite_dists(self):
 
         make_params = TensorDictModule(
             lambda: (
+                torch.ones(4),
+                torch.ones(4),
+                torch.ones(4),
+                torch.ones(4),
                 torch.ones(4),
                 torch.ones(4),
                 torch.ones(4, 2),
@@ -9669,8 +9673,12 @@ def test_ppo_composite_dists(self):
             ),
             in_keys=[],
             out_keys=[
-                ("params", "gamma", "concentration"),
-                ("params", "gamma", "rate"),
+                ("params", "gamma1", "concentration"),
+                ("params", "gamma1", "rate"),
+                ("params", "gamma2", "concentration"),
+                ("params", "gamma2", "rate"),
+                ("params", "gamma3", "concentration"),
+                ("params", "gamma3", "rate"),
                 ("params", "Kumaraswamy", "concentration0"),
                 ("params", "Kumaraswamy", "concentration1"),
                 ("params", "mixture", "logits"),
@@ -9687,24 +9695,30 @@ def mixture_constructor(logits, loc, scale):
         dist_constructor = functools.partial(
             CompositeDistribution,
             distribution_map={
-                "gamma": d.Gamma,
+                "gamma1": d.Gamma,
+                "gamma2": d.Gamma,
+                "gamma3": d.Gamma,
                 "Kumaraswamy": d.Kumaraswamy,
                 "mixture": mixture_constructor,
             },
             name_map={
-                "gamma": ("agent0", "action"),
+                "gamma1": ("agent0", "action", "action1", "sub_action1"),
+                "gamma2": ("agent0", "action", "action1", "sub_action2"),
+                "gamma3": ("agent0", "action", "action2"),
                 "Kumaraswamy": ("agent1", "action"),
-                "mixture": ("agent2", "action"),
+                "mixture": ("agent2"),
             },
         )
         policy = ProbSeq(
             make_params,
             ProbabilisticTensorDictModule(
                 in_keys=["params"],
                 out_keys=[
-                    ("agent0", "action"),
+                    ("agent0", "action", "action1", "sub_action1"),
+                    ("agent0", "action", "action1", "sub_action2"),
+                    ("agent0", "action", "action2"),
                     ("agent1", "action"),
-                    ("agent2", "action"),
+                    ("agent2"),
                 ],
                 distribution_class=dist_constructor,
                 return_log_prob=True,
@@ -9739,14 +9753,18 @@ def mixture_constructor(logits, loc, scale):
         ppo = cls(policy, value_operator, entropy_coeff=scalar_entropy)
         ppo.set_keys(
             action=[
-                ("agent0", "action"),
+                ("agent0", "action", "action1", "sub_action1"),
+                ("agent0", "action", "action1", "sub_action2"),
+                ("agent0", "action", "action2"),
                 ("agent1", "action"),
-                ("agent2", "action"),
+                ("agent2"),
             ],
             sample_log_prob=[
-                ("agent0", "action_log_prob"),
+                ("agent0", "action", "action1", "sub_action1_log_prob"),
+                ("agent0", "action", "action1", "sub_action2_log_prob"),
+                ("agent0", "action", "action2_log_prob"),
                 ("agent1", "action_log_prob"),
-                ("agent2", "action_log_prob"),
+                ("agent2_log_prob"),
             ],
         )
         loss = ppo(data)
@@ -9761,21 +9779,27 @@ def mixture_constructor(logits, loc, scale):
         # keep per-head entropies instead of the aggregated tensor
         set_composite_lp_aggregate(False).set()
         coef_map = {
-            "agent0": 0.10,
-            "agent1": 0.05,
-            "agent2": 0.02,
+            ("agent0", "action", "action1", "sub_action1_log_prob"): 0.02,
+            ("agent0", "action", "action1", "sub_action2_log_prob"): 0.01,
+            ("agent0", "action", "action2_log_prob"): 0.01,
+            ("agent1", "action_log_prob"): 0.01,
+            "agent2_log_prob": 0.01,
         }
         ppo_weighted = cls(policy, value_operator, entropy_coeff=coef_map)
         ppo_weighted.set_keys(
             action=[
-                ("agent0", "action"),
+                ("agent0", "action", "action1", "sub_action1"),
+                ("agent0", "action", "action1", "sub_action2"),
+                ("agent0", "action", "action2"),
                 ("agent1", "action"),
-                ("agent2", "action"),
+                ("agent2"),
             ],
             sample_log_prob=[
-                ("agent0", "action_log_prob"),
+                ("agent0", "action", "action1", "sub_action1_log_prob"),
+                ("agent0", "action", "action1", "sub_action2_log_prob"),
+                ("agent0", "action", "action2_log_prob"),
                 ("agent1", "action_log_prob"),
-                ("agent2", "action_log_prob"),
+                ("agent2_log_prob"),
             ],
         )
         loss = ppo_weighted(data)
@@ -9786,9 +9810,11 @@ def mixture_constructor(logits, loc, scale):
         assert torch.isfinite(loss["loss_entropy"])
         # Check individual loss is computed with the right weights
         expected_loss = 0.0
-        for name, head_entropy in composite_entropy.items():
+        for i, (_, head_entropy) in enumerate(
+            composite_entropy.items(include_nested=True, leaves_only=True)
+        ):
             expected_loss -= (
-                coef_map[name] * _sum_td_features(head_entropy)
+                coef_map[list(coef_map.keys())[i]] * head_entropy
             ).mean()
         torch.testing.assert_close(
             loss["loss_entropy"], expected_loss, rtol=1e-5, atol=1e-7
@@ -9846,7 +9872,7 @@ def test_weighted_entropy_scalar(self):
         torch.testing.assert_close(out, torch.tensor(-1.0))
 
     def test_weighted_entropy_mapping(self):
-        coef = {"head_0": 0.3, "head_1": 0.7}
+        coef = {("head_0", "action_log_prob"): 0.3, ("head_1", "action_log_prob"): 0.7}
         loss = self._make_entropy_loss(entropy_coeff=coef)
         entropy = TensorDict(
             {
@@ -9856,7 +9882,10 @@ def test_weighted_entropy_mapping(self):
             [],
         )
         out = loss._weighted_loss_entropy(entropy)
-        expected = -(coef["head_0"] * 1.0 + coef["head_1"] * 2.0)
+        expected = -(
+            coef[("head_0", "action_log_prob")] * 1.0
+            + coef[("head_1", "action_log_prob")] * 2.0
+        )
         torch.testing.assert_close(out, torch.tensor(expected))
 
     def test_weighted_entropy_mapping_missing_key(self):
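
For reference, a minimal self-contained sketch of the behaviour the updated test leans on: when CompositeDistribution's name_map points at nested keys and composite log-prob aggregation is disabled, samples and the per-head "*_log_prob" leaves are written under those nested keys. The distributions and keys below are invented for illustration (they are not the test's), and a recent tensordict is assumed.

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate
from torch import distributions as d

# keep one log-prob leaf per head instead of a single aggregated tensor
set_composite_lp_aggregate(False).set()

params = TensorDict(
    {
        "gamma1": {"concentration": torch.ones(4), "rate": torch.ones(4)},
        "gamma2": {"concentration": torch.ones(4), "rate": torch.ones(4)},
    },
    [],
)
dist = CompositeDistribution(
    params,
    distribution_map={"gamma1": d.Gamma, "gamma2": d.Gamma},
    name_map={
        "gamma1": ("agent0", "action", "action1", "sub_action1"),
        "gamma2": ("agent0", "action", "action2"),
    },
)
sample = dist.sample()      # entries live at the nested name_map keys
lp = dist.log_prob(sample)  # per-head leaves such as ("agent0", "action", "action2_log_prob")
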

torchrl/objectives/ppo.py

Lines changed: 13 additions & 11 deletions
@@ -100,7 +100,7 @@ class PPOLoss(LossModule):
             ``samples_mc_entropy`` will control how many
             samples will be used to compute this estimate.
             Defaults to ``1``.
-        entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+        entropy_coeff: scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
             * **Scalar**: one value applied to the summed entropy of every action head.
             * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
             Defaults to ``0.01``.
@@ -351,7 +351,7 @@ def __init__(
         *,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coeff: float | Mapping[str, float] | None = None,
+        entropy_coeff: float | Mapping[NestedKey, float] | None = None,
         log_explained_variance: bool = True,
         critic_coeff: float | None = None,
         loss_critic_type: str = "smooth_l1",
@@ -459,9 +459,7 @@ def __init__(
 
         if isinstance(entropy_coeff, Mapping):
             # Store the mapping for per-head coefficients
-            self._entropy_coeff_map = {
-                str(k): float(v) for k, v in entropy_coeff.items()
-            }
+            self._entropy_coeff_map = {k: float(v) for k, v in entropy_coeff.items()}
             # Register an empty buffer for compatibility
             self.register_buffer("entropy_coeff", torch.tensor(0.0))
         elif isinstance(entropy_coeff, (float, int, torch.Tensor)):
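
The functional change in this hunk is that the coefficient mapping is now stored with its keys untouched instead of coercing them through str(), so tuple-form nested keys survive and can later match TensorDict leaf keys. A small illustrative sketch of the difference (standalone Python, not library code; the keys are invented):

from collections.abc import Mapping

entropy_coeff = {("agent0", "action", "action2_log_prob"): 0.01, "agent1_log_prob": 0.05}

if isinstance(entropy_coeff, Mapping):
    old_map = {str(k): float(v) for k, v in entropy_coeff.items()}  # previous behaviour: tuple key becomes its repr string
    new_map = {k: float(v) for k, v in entropy_coeff.items()}  # behaviour after this commit: key kept as-is

nested_key = ("agent0", "action", "action2_log_prob")
print(nested_key in old_map)  # False: lookup by the nested key fails
print(nested_key in new_map)  # True: lookup by the full nested key works
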
@@ -911,22 +909,26 @@ def _weighted_loss_entropy(
 
         If `self._entropy_coeff_map` is provided, apply per-head entropy coefficients.
         Otherwise, use the scalar `self.entropy_coeff`.
+        The entries in self._entropy_coeff_map require the full nested key to the entropy head.
         """
         if self._entropy_coeff_map is None:
             if is_tensor_collection(entropy):
                 entropy = _sum_td_features(entropy)
             return -self.entropy_coeff * entropy
 
         loss_term = None  # running sum over heads
-        for head_name, entropy_head in entropy.items():
+        coeff = 0
+        for head_name, entropy_head in entropy.items(
+            include_nested=True, leaves_only=True
+        ):
             try:
                 coeff = self._entropy_coeff_map[head_name]
             except KeyError as exc:
                 raise KeyError(f"Missing entropy coeff for head '{head_name}'") from exc
             coeff_t = torch.as_tensor(
                 coeff, dtype=entropy_head.dtype, device=entropy_head.device
             )
-            head_loss_term = -coeff_t * _sum_td_features(entropy_head)
+            head_loss_term = -coeff_t * entropy_head
             loss_term = (
                 head_loss_term if loss_term is None else loss_term + head_loss_term
             )  # accumulate
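
A standalone sketch of the accumulation performed by the loop above once the entropy arrives as a nested TensorDict: each leaf is looked up in the coefficient map by its full nested key. The keys and values here are invented for illustration; only torch and tensordict are assumed.

import torch
from tensordict import TensorDict

entropy = TensorDict(
    {
        "agent0": {"action": {"action2_log_prob": torch.tensor(0.5)}},
        "agent1": {"action_log_prob": torch.tensor(1.5)},
    },
    [],
)
coeff_map = {
    ("agent0", "action", "action2_log_prob"): 0.01,
    ("agent1", "action_log_prob"): 0.05,
}
loss_term = None
for head_name, entropy_head in entropy.items(include_nested=True, leaves_only=True):
    coeff = coeff_map[head_name]  # requires the full nested key, as the docstring now states
    term = -coeff * entropy_head  # weight this head's entropy
    loss_term = term if loss_term is None else loss_term + term
print(loss_term)  # tensor(-0.0800) == -(0.01 * 0.5 + 0.05 * 1.5)
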
@@ -970,7 +972,7 @@ class ClipPPOLoss(PPOLoss):
             ``samples_mc_entropy`` will control how many
             samples will be used to compute this estimate.
             Defaults to ``1``.
-        entropy_coeff: (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+        entropy_coeff: (scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
             * **Scalar**: one value applied to the summed entropy of every action head.
             * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
             Defaults to ``0.01``.
@@ -1075,7 +1077,7 @@ def __init__(
         clip_epsilon: float = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coeff: float | Mapping[str, float] | None = None,
+        entropy_coeff: float | Mapping[NestedKey, float] | None = None,
         critic_coeff: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -1263,7 +1265,7 @@ class KLPENPPOLoss(PPOLoss):
             ``samples_mc_entropy`` will control how many
             samples will be used to compute this estimate.
             Defaults to ``1``.
-        entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+        entropy_coeff: scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
             * **Scalar**: one value applied to the summed entropy of every action head.
             * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
             Defaults to ``0.01``.
@@ -1369,7 +1371,7 @@ def __init__(
         samples_mc_kl: int = 1,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coeff: float | Mapping[str, float] | None = None,
+        entropy_coeff: float | Mapping[NestedKey, float] | None = None,
         critic_coeff: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
