[Feature,BugFix] Fix composite entropy for nested keys #3101

Merged: 5 commits, Aug 1, 2025
73 changes: 51 additions & 22 deletions test/test_cost.py
@@ -9659,6 +9659,10 @@ def test_ppo_composite_dists(self):

make_params = TensorDictModule(
lambda: (
torch.ones(4),
torch.ones(4),
torch.ones(4),
torch.ones(4),
torch.ones(4),
torch.ones(4),
torch.ones(4, 2),
@@ -9669,8 +9673,12 @@ def test_ppo_composite_dists(self):
),
in_keys=[],
out_keys=[
("params", "gamma", "concentration"),
("params", "gamma", "rate"),
("params", "gamma1", "concentration"),
("params", "gamma1", "rate"),
("params", "gamma2", "concentration"),
("params", "gamma2", "rate"),
("params", "gamma3", "concentration"),
("params", "gamma3", "rate"),
("params", "Kumaraswamy", "concentration0"),
("params", "Kumaraswamy", "concentration1"),
("params", "mixture", "logits"),
@@ -9687,24 +9695,30 @@ def mixture_constructor(logits, loc, scale):
dist_constructor = functools.partial(
CompositeDistribution,
distribution_map={
"gamma": d.Gamma,
"gamma1": d.Gamma,
"gamma2": d.Gamma,
"gamma3": d.Gamma,
"Kumaraswamy": d.Kumaraswamy,
"mixture": mixture_constructor,
},
name_map={
"gamma": ("agent0", "action"),
"gamma1": ("agent0", "action", "action1", "sub_action1"),
"gamma2": ("agent0", "action", "action1", "sub_action2"),
"gamma3": ("agent0", "action", "action2"),
"Kumaraswamy": ("agent1", "action"),
"mixture": ("agent2", "action"),
"mixture": ("agent2"),
},
)
policy = ProbSeq(
make_params,
ProbabilisticTensorDictModule(
in_keys=["params"],
out_keys=[
("agent0", "action"),
("agent0", "action", "action1", "sub_action1"),
("agent0", "action", "action1", "sub_action2"),
("agent0", "action", "action2"),
("agent1", "action"),
("agent2", "action"),
("agent2"),
],
distribution_class=dist_constructor,
return_log_prob=True,
@@ -9739,14 +9753,18 @@ def mixture_constructor(logits, loc, scale):
ppo = cls(policy, value_operator, entropy_coeff=scalar_entropy)
ppo.set_keys(
action=[
("agent0", "action"),
("agent0", "action", "action1", "sub_action1"),
("agent0", "action", "action1", "sub_action2"),
("agent0", "action", "action2"),
("agent1", "action"),
("agent2", "action"),
("agent2"),
],
sample_log_prob=[
("agent0", "action_log_prob"),
("agent0", "action", "action1", "sub_action1_log_prob"),
("agent0", "action", "action1", "sub_action2_log_prob"),
("agent0", "action", "action2_log_prob"),
("agent1", "action_log_prob"),
("agent2", "action_log_prob"),
("agent2_log_prob"),
],
)
loss = ppo(data)
@@ -9761,21 +9779,27 @@ def mixture_constructor(logits, loc, scale):
# keep per-head entropies instead of the aggregated tensor
set_composite_lp_aggregate(False).set()
coef_map = {
"agent0": 0.10,
"agent1": 0.05,
"agent2": 0.02,
("agent0", "action", "action1", "sub_action1_log_prob"): 0.02,
("agent0", "action", "action1", "sub_action2_log_prob"): 0.01,
("agent0", "action", "action2_log_prob"): 0.01,
("agent1", "action_log_prob"): 0.01,
"agent2_log_prob": 0.01,
}
ppo_weighted = cls(policy, value_operator, entropy_coeff=coef_map)
ppo_weighted.set_keys(
action=[
("agent0", "action"),
("agent0", "action", "action1", "sub_action1"),
("agent0", "action", "action1", "sub_action2"),
("agent0", "action", "action2"),
("agent1", "action"),
("agent2", "action"),
("agent2"),
],
sample_log_prob=[
("agent0", "action_log_prob"),
("agent0", "action", "action1", "sub_action1_log_prob"),
("agent0", "action", "action1", "sub_action2_log_prob"),
("agent0", "action", "action2_log_prob"),
("agent1", "action_log_prob"),
("agent2", "action_log_prob"),
("agent2_log_prob"),
],
)
loss = ppo_weighted(data)
@@ -9786,9 +9810,11 @@ def mixture_constructor(logits, loc, scale):
assert torch.isfinite(loss["loss_entropy"])
# Check individual loss is computed with the right weights
expected_loss = 0.0
for name, head_entropy in composite_entropy.items():
for i, (_, head_entropy) in enumerate(
composite_entropy.items(include_nested=True, leaves_only=True)
):
expected_loss -= (
coef_map[name] * _sum_td_features(head_entropy)
coef_map[list(coef_map.keys())[i]] * head_entropy
).mean()
torch.testing.assert_close(
loss["loss_entropy"], expected_loss, rtol=1e-5, atol=1e-7
@@ -9846,7 +9872,7 @@ def test_weighted_entropy_scalar(self):
torch.testing.assert_close(out, torch.tensor(-1.0))

def test_weighted_entropy_mapping(self):
coef = {"head_0": 0.3, "head_1": 0.7}
coef = {("head_0", "action_log_prob"): 0.3, ("head_1", "action_log_prob"): 0.7}
loss = self._make_entropy_loss(entropy_coeff=coef)
entropy = TensorDict(
{
@@ -9856,7 +9882,10 @@ def test_weighted_entropy_mapping(self):
[],
)
out = loss._weighted_loss_entropy(entropy)
expected = -(coef["head_0"] * 1.0 + coef["head_1"] * 2.0)
expected = -(
coef[("head_0", "action_log_prob")] * 1.0
+ coef[("head_1", "action_log_prob")] * 2.0
)
torch.testing.assert_close(out, torch.tensor(expected))

def test_weighted_entropy_mapping_missing_key(self):
24 changes: 13 additions & 11 deletions torchrl/objectives/ppo.py
@@ -100,7 +100,7 @@ class PPOLoss(LossModule):
``samples_mc_entropy`` will control how many
samples will be used to compute this estimate.
Defaults to ``1``.
entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
entropy_coeff: scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.
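
A minimal sketch of the mapping form described above (the values and nested tuple keys below are hypothetical, mirroring the log-prob keys exercised in the updated test; the constructor call is left commented because it assumes actor and critic modules defined elsewhere):

# Hypothetical per-head coefficients, keyed by the full nested log-prob key
# of each action head, as the updated test does.
entropy_coeff = {
    ("agent0", "action", "action1", "sub_action1_log_prob"): 0.02,
    ("agent0", "action", "action1", "sub_action2_log_prob"): 0.01,
    ("agent1", "action_log_prob"): 0.01,
}
# loss = ClipPPOLoss(actor, critic, entropy_coeff=entropy_coeff)  # actor/critic assumed to exist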
@@ -351,7 +351,7 @@ def __init__(
*,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coeff: float | Mapping[str, float] | None = None,
entropy_coeff: float | Mapping[NestedKey, float] | None = None,
log_explained_variance: bool = True,
critic_coeff: float | None = None,
loss_critic_type: str = "smooth_l1",
@@ -459,9 +459,7 @@ def __init__(

if isinstance(entropy_coeff, Mapping):
# Store the mapping for per-head coefficients
self._entropy_coeff_map = {
str(k): float(v) for k, v in entropy_coeff.items()
}
self._entropy_coeff_map = {k: float(v) for k, v in entropy_coeff.items()}
# Register an empty buffer for compatibility
self.register_buffer("entropy_coeff", torch.tensor(0.0))
elif isinstance(entropy_coeff, (float, int, torch.Tensor)):
@@ -911,22 +909,26 @@ def _weighted_loss_entropy(

If `self._entropy_coeff_map` is provided, apply per-head entropy coefficients.
Otherwise, use the scalar `self.entropy_coeff`.
The entries in self._entropy_coeff_map require the full nested key to the entropy head.
"""
if self._entropy_coeff_map is None:
if is_tensor_collection(entropy):
entropy = _sum_td_features(entropy)
return -self.entropy_coeff * entropy

Contributor: General comment: I feel like this method could use a bit more inline comments to guide the reader.

Collaborator: Yep, agree on that. I can give it a shot.

loss_term = None # running sum over heads
for head_name, entropy_head in entropy.items():
coeff = 0
for head_name, entropy_head in entropy.items(
include_nested=True, leaves_only=True
):
try:
coeff = self._entropy_coeff_map[head_name]
except KeyError as exc:
raise KeyError(f"Missing entropy coeff for head '{head_name}'") from exc
coeff_t = torch.as_tensor(
coeff, dtype=entropy_head.dtype, device=entropy_head.device
)
head_loss_term = -coeff_t * _sum_td_features(entropy_head)
head_loss_term = -coeff_t * entropy_head
loss_term = (
head_loss_term if loss_term is None else loss_term + head_loss_term
) # accumulate
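
To make the per-head weighting easier to follow (the reviewer's point above about inline comments), here is a standalone sketch of the same accumulation using plain tensors and an ordinary dict rather than a TensorDict; it is illustrative only, not the library implementation:

import torch

# Illustration of the per-head weighting: each head's entropy is scaled by its
# own coefficient and the negated terms are summed into a single loss term.
def weighted_entropy_loss(entropy_by_head, coeff_map):
    loss_term = None
    for key, head_entropy in entropy_by_head.items():
        head_loss_term = -coeff_map[key] * head_entropy  # weight this head's entropy
        loss_term = head_loss_term if loss_term is None else loss_term + head_loss_term
    return loss_term

entropy = {
    ("head_0", "action_log_prob"): torch.tensor(1.0),
    ("head_1", "action_log_prob"): torch.tensor(2.0),
}
coeffs = {("head_0", "action_log_prob"): 0.3, ("head_1", "action_log_prob"): 0.7}
print(weighted_entropy_loss(entropy, coeffs))  # tensor(-1.7000), i.e. -(0.3*1.0 + 0.7*2.0)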
@@ -970,7 +972,7 @@ class ClipPPOLoss(PPOLoss):
``samples_mc_entropy`` will control how many
samples will be used to compute this estimate.
Defaults to ``1``.
entropy_coeff: (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
entropy_coeff: (scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.
@@ -1075,7 +1077,7 @@ def __init__(
clip_epsilon: float = 0.2,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coeff: float | Mapping[str, float] | None = None,
entropy_coeff: float | Mapping[NestedKey, float] | None = None,
critic_coeff: float | None = None,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
@@ -1263,7 +1265,7 @@ class KLPENPPOLoss(PPOLoss):
``samples_mc_entropy`` will control how many
samples will be used to compute this estimate.
Defaults to ``1``.
entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
entropy_coeff: scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.
@@ -1369,7 +1371,7 @@ def __init__(
samples_mc_kl: int = 1,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coeff: float | Mapping[str, float] | None = None,
entropy_coeff: float | Mapping[NestedKey, float] | None = None,
critic_coeff: float | None = None,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,