[Doc] Better doc on multi-head entropy #3109

Open · wants to merge 1 commit into base: main
85 changes: 80 additions & 5 deletions torchrl/objectives/ppo.py
@@ -104,6 +104,8 @@ class PPOLoss(LossModule):
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.

See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
predictions w.r.t. value targets will be computed and logged as ``"explained_variance"``.
This can help monitor critic quality during training. Best possible score is 1.0, lower values are worse. Defaults to ``True``.
@@ -217,7 +219,7 @@ class PPOLoss(LossModule):
>>> action = spec.rand(batch)
>>> data = TensorDict({"observation": torch.randn(*batch, n_obs),
... "action": action,
... "sample_log_prob": torch.randn_like(action[..., 1]),
... "action_log_prob": torch.randn_like(action[..., 1]),
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
@@ -227,6 +229,8 @@ class PPOLoss(LossModule):
TensorDict(
fields={
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
explained_variance: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
kl_approx: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
@@ -279,12 +283,69 @@ class PPOLoss(LossModule):
... next_observation=torch.randn(*batch, n_obs))
>>> loss_objective.backward()

**Simple Entropy Coefficient Examples**:

>>> # Scalar entropy coefficient (default behavior)
>>> loss = PPOLoss(actor, critic, entropy_coeff=0.01)
>>>
>>> # Per-head entropy coefficients (for composite action spaces)
>>> entropy_coeff = {
... ("agent0", "action_log_prob"): 0.01, # Low exploration
... ("agent1", "action_log_prob"): 0.05, # High exploration
... }
>>> loss = PPOLoss(actor, critic, entropy_coeff=entropy_coeff)

.. note::
There is an exception regarding compatibility with non-tensordict-based modules.
If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`,
this class must be used with tensordicts and cannot function as a tensordict-independent module.
This is because composite action spaces inherently rely on the structured representation of data provided by
tensordicts to handle their actions.

.. _ppo_entropy_coefficients:

.. note::
**Entropy Bonus and Coefficient Management**

The entropy bonus encourages exploration by adding the negative entropy of the policy to the loss.
This can be configured in two ways:

**Scalar Coefficient (Default)**: Use a single coefficient for all action heads:

>>> loss = PPOLoss(actor, critic, entropy_coeff=0.01)

**Per-Head Coefficients**: Use different coefficients for different action components:

>>> # For a robot with movement and gripper actions
>>> entropy_coeff = {
...     ("movement", "action_log_prob"): 0.01,  # Movement: low exploration
...     ("gripper", "action_log_prob"): 0.05,   # Gripper: high exploration
... }
>>> loss = PPOLoss(actor, critic, entropy_coeff=entropy_coeff)

**Key Requirements**: When using per-head coefficients, you must provide the full nested key
path to each action head's log probability (e.g., ``("movement", "action_log_prob")``).
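
One convenient way to check which log-probability keys your policy writes (and hence which
keys the mapping must contain) is to inspect its output keys. The key names shown below are
purely illustrative and depend on how the policy is configured:

>>> print(policy.out_keys)
[('movement', 'action'), ('movement', 'action_log_prob'), ('gripper', 'action'), ('gripper', 'action_log_prob')]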

**Monitoring Entropy Loss**:

When using composite action spaces, the loss output includes:

- ``"entropy"``: Summed entropy across all action heads (for logging)
- ``"composite_entropy"``: Individual entropy values for each action head
- ``"loss_entropy"``: The weighted entropy loss term

Example output:

>>> result = loss(data)
>>> print(result["entropy"]) # Total entropy: 2.34
>>> print(result["composite_entropy"]) # Per-head: {"movement": 1.2, "gripper": 1.14}
>>> print(result["loss_entropy"]) # Weighted loss: -0.0234

**Common Issues**:

**KeyError: "Missing entropy coeff for head 'head_name'"**:

- Ensure you provide coefficients for ALL action heads
- Use full nested keys: ``("head_name", "action_log_prob")``
- Check that your action space structure matches the coefficient mapping

**Incorrect Entropy Calculation**:

- Call ``set_composite_lp_aggregate(False).set()`` before creating your policy
- Verify that your action space uses :class:`~tensordict.nn.distributions.CompositeDistribution`
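
A minimal end-to-end sketch of this setup, assuming a composite-distribution policy with
hypothetical ``movement`` and ``gripper`` heads (``make_policy`` is a placeholder for your
own actor constructor, and the coefficient values are illustrative):

>>> from tensordict.nn import set_composite_lp_aggregate
>>> set_composite_lp_aggregate(False).set()  # keep per-head log-probs as separate entries
>>> policy = make_policy()  # hypothetical: builds an actor using a CompositeDistribution
>>> loss = PPOLoss(
...     policy,
...     critic,
...     entropy_coeff={
...         ("movement", "action_log_prob"): 0.01,
...         ("gripper", "action_log_prob"): 0.05,
...     },
... )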
"""

@dataclass
@@ -911,27 +972,37 @@ def _weighted_loss_entropy(
Otherwise, use the scalar `self.entropy_coeff`.
The entries in self._entropy_coeff_map require the full nested key to the entropy head.
"""
# Mode 1: Use scalar entropy coefficient (default behavior)
if self._entropy_coeff_map is None:
# If entropy is a TensorDict (composite action space), sum all entropy values
if is_tensor_collection(entropy):
entropy = _sum_td_features(entropy)
# Apply scalar coefficient: loss = -coeff * entropy (negative for maximization)
return -self.entropy_coeff * entropy

loss_term = None # running sum over heads
coeff = 0
# Mode 2: Use per-head entropy coefficients (for composite action spaces)
loss_term = None # Initialize running sum over action heads
coeff = 0 # Placeholder for coefficient value
# Iterate through all entropy heads in the composite action space
for head_name, entropy_head in entropy.items(
include_nested=True, leaves_only=True
):
try:
# Look up the coefficient for this specific action head
coeff = self._entropy_coeff_map[head_name]
except KeyError as exc:
# Provide clear error message if coefficient mapping is incomplete
raise KeyError(f"Missing entropy coeff for head '{head_name}'") from exc
# Convert coefficient to tensor with matching dtype and device
coeff_t = torch.as_tensor(
coeff, dtype=entropy_head.dtype, device=entropy_head.device
)
# Compute weighted loss for this head: -coeff * entropy
head_loss_term = -coeff_t * entropy_head
# Accumulate loss terms across all heads
loss_term = (
head_loss_term if loss_term is None else loss_term + head_loss_term
) # accumulate
)

return loss_term

@@ -972,10 +1043,12 @@ class ClipPPOLoss(PPOLoss):
``samples_mc_entropy`` will control how many
samples will be used to compute this estimate.
Defaults to ``1``.
entropy_coeff: (scalar | Mapping[NesstedKey, scalar], optional): entropy multiplier when computing the total loss.
entropy_coeff: (scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.

See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
critic_coeff (scalar, optional): critic loss multiplier when computing the total
loss. Defaults to ``1.0``. Set ``critic_coeff`` to ``None`` to exclude the value
loss from the forward outputs.
@@ -1269,6 +1342,8 @@ class KLPENPPOLoss(PPOLoss):
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.

See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
critic_coeff (scalar, optional): critic loss multiplier when computing the total
loss. Defaults to ``1.0``.
loss_critic_type (str, optional): loss function for the value discrepancy.