
Commit d1124af

Juan de los Rios authored and vmoens committed

Fix composite entropy nested keys

1 parent 978424e commit d1124af

File tree

2 files changed (+79, -30 lines)


test/test_cost.py

Lines changed: 46 additions & 20 deletions
@@ -9659,6 +9659,10 @@ def test_ppo_composite_dists(self):
 
         make_params = TensorDictModule(
             lambda: (
+                torch.ones(4),
+                torch.ones(4),
+                torch.ones(4),
+                torch.ones(4),
                 torch.ones(4),
                 torch.ones(4),
                 torch.ones(4, 2),
@@ -9669,8 +9673,12 @@ def test_ppo_composite_dists(self):
             ),
             in_keys=[],
             out_keys=[
-                ("params", "gamma", "concentration"),
-                ("params", "gamma", "rate"),
+                ("params", "gamma1", "concentration"),
+                ("params", "gamma1", "rate"),
+                ("params", "gamma2", "concentration"),
+                ("params", "gamma2", "rate"),
+                ("params", "gamma3", "concentration"),
+                ("params", "gamma3", "rate"),
                 ("params", "Kumaraswamy", "concentration0"),
                 ("params", "Kumaraswamy", "concentration1"),
                 ("params", "mixture", "logits"),
@@ -9687,24 +9695,30 @@ def mixture_constructor(logits, loc, scale):
         dist_constructor = functools.partial(
             CompositeDistribution,
             distribution_map={
-                "gamma": d.Gamma,
+                "gamma1": d.Gamma,
+                "gamma2": d.Gamma,
+                "gamma3": d.Gamma,
                 "Kumaraswamy": d.Kumaraswamy,
                 "mixture": mixture_constructor,
             },
             name_map={
-                "gamma": ("agent0", "action"),
+                "gamma1": ("agent0", "action", "action1", "sub_action1"),
+                "gamma2": ("agent0", "action", "action1", "sub_action2"),
+                "gamma3": ("agent0", "action", "action2"),
                 "Kumaraswamy": ("agent1", "action"),
-                "mixture": ("agent2", "action"),
+                "mixture": ("agent2"),
             },
         )
         policy = ProbSeq(
             make_params,
             ProbabilisticTensorDictModule(
                 in_keys=["params"],
                 out_keys=[
-                    ("agent0", "action"),
+                    ("agent0", "action", "action1", "sub_action1"),
+                    ("agent0", "action", "action1", "sub_action2"),
+                    ("agent0", "action", "action2"),
                     ("agent1", "action"),
-                    ("agent2", "action"),
+                    ("agent2"),
                 ],
                 distribution_class=dist_constructor,
                 return_log_prob=True,
@@ -9739,14 +9753,18 @@ def mixture_constructor(logits, loc, scale):
         ppo = cls(policy, value_operator, entropy_coeff=scalar_entropy)
         ppo.set_keys(
             action=[
-                ("agent0", "action"),
+                ("agent0", "action", "action1", "sub_action1"),
+                ("agent0", "action", "action1", "sub_action2"),
+                ("agent0", "action", "action2"),
                 ("agent1", "action"),
-                ("agent2", "action"),
+                ("agent2"),
             ],
             sample_log_prob=[
-                ("agent0", "action_log_prob"),
+                ("agent0", "action", "action1", "sub_action1_log_prob"),
+                ("agent0", "action", "action1", "sub_action2_log_prob"),
+                ("agent0", "action", "action2_log_prob"),
                 ("agent1", "action_log_prob"),
-                ("agent2", "action_log_prob"),
+                ("agent2_log_prob"),
             ],
         )
         loss = ppo(data)
@@ -9761,21 +9779,27 @@ def mixture_constructor(logits, loc, scale):
         # keep per-head entropies instead of the aggregated tensor
         set_composite_lp_aggregate(False).set()
         coef_map = {
-            "agent0": 0.10,
-            "agent1": 0.05,
-            "agent2": 0.02,
+            ("agent0", "action", "action1", "sub_action1_log_prob"): 0.10,
+            "sub_action2_log_prob": 0.10,
+            "action2": 0.10,
+            ("agent1", "action_log_prob"): 0.10,
+            "agent2_log_prob": 0.02,
         }
         ppo_weighted = cls(policy, value_operator, entropy_coeff=coef_map)
         ppo_weighted.set_keys(
             action=[
-                ("agent0", "action"),
+                ("agent0", "action", "action1", "sub_action1"),
+                ("agent0", "action", "action1", "sub_action2"),
+                ("agent0", "action", "action2"),
                 ("agent1", "action"),
-                ("agent2", "action"),
+                ("agent2"),
             ],
             sample_log_prob=[
-                ("agent0", "action_log_prob"),
+                ("agent0", "action", "action1", "sub_action1_log_prob"),
+                ("agent0", "action", "action1", "sub_action2_log_prob"),
+                ("agent0", "action", "action2_log_prob"),
                 ("agent1", "action_log_prob"),
-                ("agent2", "action_log_prob"),
+                ("agent2_log_prob"),
             ],
         )
         loss = ppo_weighted(data)
@@ -9786,9 +9810,11 @@ def mixture_constructor(logits, loc, scale):
         assert torch.isfinite(loss["loss_entropy"])
         # Check individual loss is computed with the right weights
         expected_loss = 0.0
-        for name, head_entropy in composite_entropy.items():
+        for i, (_, head_entropy) in enumerate(
+            composite_entropy.items(include_nested=True, leaves_only=True)
+        ):
             expected_loss -= (
-                coef_map[name] * _sum_td_features(head_entropy)
+                coef_map[list(coef_map.keys())[i]] * head_entropy
             ).mean()
         torch.testing.assert_close(
             loss["loss_entropy"], expected_loss, rtol=1e-5, atol=1e-7

torchrl/objectives/ppo.py

Lines changed: 33 additions & 10 deletions
@@ -351,7 +351,7 @@ def __init__(
         *,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coeff: float | Mapping[str, float] | None = None,
+        entropy_coeff: float | Mapping[str | tuple | list, float] | None = None,
         log_explained_variance: bool = True,
         critic_coeff: float | None = None,
         loss_critic_type: str = "smooth_l1",
@@ -460,7 +460,8 @@ def __init__(
         if isinstance(entropy_coeff, Mapping):
             # Store the mapping for per-head coefficients
             self._entropy_coeff_map = {
-                str(k): float(v) for k, v in entropy_coeff.items()
+                (tuple(k) if isinstance(k, list) else k): float(v)
+                for k, v in entropy_coeff.items()
             }
             # Register an empty buffer for compatibility
             self.register_buffer("entropy_coeff", torch.tensor(0.0))
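
A small illustration (not part of the commit) of why the normalization above replaces the previous `str(k)`: stringifying a nested key destroys the path structure that the entropy TensorDict iterates with, while `tuple(...)` keeps the key comparable against tuple paths:

    # Old behavior: str() turns a path into an opaque string.
    str(("agent0", "action"))  # "('agent0', 'action')" -- never matches a path
    # New behavior: lists become tuples; strings and tuples pass through.
    k = ["agent0", "action"]
    k = tuple(k) if isinstance(k, list) else k  # ("agent0", "action")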
@@ -918,15 +919,37 @@ def _weighted_loss_entropy(
             return -self.entropy_coeff * entropy
 
         loss_term = None  # running sum over heads
-        for head_name, entropy_head in entropy.items():
-            try:
-                coeff = self._entropy_coeff_map[head_name]
-            except KeyError as exc:
-                raise KeyError(f"Missing entropy coeff for head '{head_name}'") from exc
+        coeff = 0
+        for head_name, entropy_head in entropy.items(
+            include_nested=True, leaves_only=True
+        ):
+            if isinstance(head_name, str):
+                head_name = (head_name,)
+            for i, (head_name_from_map, _coeff) in enumerate(
+                self._entropy_coeff_map.items()
+            ):
+                # Check if a distinct head name appears inside the nested key tuple
+                if head_name_from_map in head_name:
+                    coeff = _coeff
+                    break
+                # Check if the head's path is fully or partially in the nested dict
+                if any(
+                    head_name_from_map == head_name[i : i + len(head_name_from_map)]
+                    for i in range(len(head_name) - len(head_name_from_map) + 1)
+                ):
+                    coeff = _coeff
+                    break
+            if i == len(self._entropy_coeff_map.items()):
+                raise KeyError(
+                    f"Missing entropy coeff for head '{head_name}'"
+                )
             coeff_t = torch.as_tensor(
                 coeff, dtype=entropy_head.dtype, device=entropy_head.device
             )
-            head_loss_term = -coeff_t * _sum_td_features(entropy_head)
+            if isinstance(entropy_head, torch.Tensor):
+                head_loss_term = -coeff_t * entropy_head
+            else:
+                head_loss_term = -coeff_t * _sum_td_features(entropy_head)
             loss_term = (
                 head_loss_term if loss_term is None else loss_term + head_loss_term
             )  # accumulate
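
For readers skimming the hunk, a standalone sketch of the matching rule it implements; the helper name `match_coeff` is hypothetical, and this mirrors, rather than calls, the torchrl code:

    from typing import Mapping

    def match_coeff(head_path: tuple, coeff_map: Mapping) -> float | None:
        # Return the first coefficient whose key is either a single element
        # of the nested head path or a contiguous sub-path of it.
        for key, coeff in coeff_map.items():
            if key in head_path:  # string key matching one path element
                return coeff
            if isinstance(key, tuple) and any(  # tuple key matching a sub-path
                key == head_path[i : i + len(key)]
                for i in range(len(head_path) - len(key) + 1)
            ):
                return coeff
        return None

    assert match_coeff(("agent2_log_prob",), {"agent2_log_prob": 0.02}) == 0.02
    assert match_coeff(
        ("agent0", "action", "action1", "sub_action1_log_prob"),
        {("action1", "sub_action1_log_prob"): 0.1},
    ) == 0.1

One divergence worth noting: the sketch surfaces a miss as `None`, whereas the loop in the diff falls through with the previously matched coefficient, since its `i == len(...)` guard cannot fire (the enumerate index stops at `len - 1`).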
@@ -1075,7 +1098,7 @@ def __init__(
         clip_epsilon: float = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coeff: float | Mapping[str, float] | None = None,
+        entropy_coeff: float | Mapping[str | tuple | list, float] | None = None,
         critic_coeff: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -1369,7 +1392,7 @@ def __init__(
         samples_mc_kl: int = 1,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coeff: float | Mapping[str, float] | None = None,
+        entropy_coeff: float | Mapping[str | tuple | list, float] | None = None,
         critic_coeff: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
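
Taken together, the changes let any PPO-family loss accept per-head entropy coefficients keyed by nested paths. A hedged usage sketch, with `policy` and `value_operator` standing in for modules built as in the test above:

    from torchrl.objectives import ClipPPOLoss

    # Keys may be full nested paths (tuples), single path elements (strings),
    # or lists (normalized to tuples on construction).
    entropy_coeff = {
        ("agent0", "action", "action1", "sub_action1_log_prob"): 0.10,
        "sub_action2_log_prob": 0.10,
        ("agent1", "action_log_prob"): 0.05,
        "agent2_log_prob": 0.02,
    }
    loss_fn = ClipPPOLoss(policy, value_operator, entropy_coeff=entropy_coeff)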
