diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 23fb856a413..e47d5baf30c 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -104,6 +104,8 @@ class PPOLoss(LossModule):
             * **Scalar**: one value applied to the summed entropy of every action head.
             * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
             Defaults to ``0.01``.
+
+            See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
         log_explained_variance (bool, optional): if ``True``, the explained variance of the critic predictions
             w.r.t. value targets will be computed and logged as ``"explained_variance"``.
             This can help monitor critic quality during training. Best possible score is 1.0, lower values are worse. Defaults to ``True``.
@@ -217,7 +219,7 @@ class PPOLoss(LossModule):
         >>> action = spec.rand(batch)
         >>> data = TensorDict({"observation": torch.randn(*batch, n_obs),
         ...         "action": action,
-        ...         "sample_log_prob": torch.randn_like(action[..., 1]),
+        ...         "action_log_prob": torch.randn_like(action[..., 1]),
         ...         ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
         ...         ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
         ...         ("next", "reward"): torch.randn(*batch, 1),
@@ -227,6 +229,8 @@ class PPOLoss(LossModule):
         TensorDict(
             fields={
                 entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                explained_variance: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                kl_approx: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                 loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                 loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                 loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
@@ -279,12 +283,69 @@ class PPOLoss(LossModule):
         ...     next_observation=torch.randn(*batch, n_obs))
         >>> loss_objective.backward()

+    **Simple Entropy Coefficient Examples**:
+        >>> # Scalar entropy coefficient (default behavior)
+        >>> loss = PPOLoss(actor, critic, entropy_coeff=0.01)
+        >>>
+        >>> # Per-head entropy coefficients (for composite action spaces)
+        >>> entropy_coeff = {
+        ...     ("agent0", "action_log_prob"): 0.01,  # Low exploration
+        ...     ("agent1", "action_log_prob"): 0.05,  # High exploration
+        ... }
+        >>> loss = PPOLoss(actor, critic, entropy_coeff=entropy_coeff)
+
     .. note:: There is an exception regarding compatibility with non-tensordict-based modules.
       If the actor network is probabilistic and uses a
       :class:`~tensordict.nn.distributions.CompositeDistribution`, this class must be used with
       tensordicts and cannot function as a tensordict-independent module. This is because
       composite action spaces inherently rely on the structured representation of data provided
       by tensordicts to handle their actions.
+
+    .. _ppo_entropy_coefficients:
+
+    .. note::
+        **Entropy Bonus and Coefficient Management**
+
+        The entropy bonus encourages exploration by adding the negative entropy of the policy to the loss.
+        This can be configured in two ways:
+
+        **Scalar Coefficient (Default)**: Use a single coefficient for all action heads:
+            >>> loss = PPOLoss(actor, critic, entropy_coeff=0.01)
+
+        **Per-Head Coefficients**: Use different coefficients for different action components:
+            >>> # For a robot with movement and gripper actions
+            >>> entropy_coeff = {
+            ...     ("agent0", "action_log_prob"): 0.01,  # Movement: low exploration
+            ...     ("agent1", "action_log_prob"): 0.05,  # Gripper: high exploration
+            ... }
+            >>> loss = PPOLoss(actor, critic, entropy_coeff=entropy_coeff)
+
+        **Key Requirements**: When using per-head coefficients, you must provide the full nested key
+        path to each action head's log probability (e.g., `("agent0", "action_log_prob")`).
+
+        **Monitoring Entropy Loss**:
+
+        When using composite action spaces, the loss output includes:
+        - `"entropy"`: Summed entropy across all action heads (for logging)
+        - `"composite_entropy"`: Individual entropy values for each action head
+        - `"loss_entropy"`: The weighted entropy loss term
+
+        Example output:
+            >>> result = loss(data)
+            >>> print(result["entropy"])  # Total entropy: 2.34
+            >>> print(result["composite_entropy"])  # Per-head: {"movement": 1.2, "gripper": 1.14}
+            >>> print(result["loss_entropy"])  # Weighted loss: -0.0234
+
+        **Common Issues**:
+
+        **KeyError: "Missing entropy coeff for head 'head_name'"**:
+        - Ensure you provide coefficients for ALL action heads
+        - Use full nested keys: `("head_name", "action_log_prob")`
+        - Check that your action space structure matches the coefficient mapping
+
+        **Incorrect Entropy Calculation**:
+        - Call `set_composite_lp_aggregate(False).set()` before creating your policy
+        - Verify that your action space uses :class:`~tensordict.nn.distributions.CompositeDistribution`
     """

     @dataclass
@@ -911,27 +972,37 @@ def _weighted_loss_entropy(
         Otherwise, use the scalar `self.entropy_coeff`.
         The entries in self._entropy_coeff_map require the full nested key to the entropy head.
         """
+        # Mode 1: Use scalar entropy coefficient (default behavior)
         if self._entropy_coeff_map is None:
+            # If entropy is a TensorDict (composite action space), sum all entropy values
             if is_tensor_collection(entropy):
                 entropy = _sum_td_features(entropy)
+            # Apply scalar coefficient: loss = -coeff * entropy (negative for maximization)
             return -self.entropy_coeff * entropy

-        loss_term = None  # running sum over heads
-        coeff = 0
+        # Mode 2: Use per-head entropy coefficients (for composite action spaces)
+        loss_term = None  # Initialize running sum over action heads
+        coeff = 0  # Placeholder for coefficient value
+        # Iterate through all entropy heads in the composite action space
         for head_name, entropy_head in entropy.items(
             include_nested=True, leaves_only=True
         ):
             try:
+                # Look up the coefficient for this specific action head
                 coeff = self._entropy_coeff_map[head_name]
             except KeyError as exc:
+                # Provide clear error message if coefficient mapping is incomplete
                 raise KeyError(f"Missing entropy coeff for head '{head_name}'") from exc
+            # Convert coefficient to tensor with matching dtype and device
             coeff_t = torch.as_tensor(
                 coeff, dtype=entropy_head.dtype, device=entropy_head.device
             )
+            # Compute weighted loss for this head: -coeff * entropy
             head_loss_term = -coeff_t * entropy_head
+            # Accumulate loss terms across all heads
             loss_term = (
                 head_loss_term if loss_term is None else loss_term + head_loss_term
-            )  # accumulate
+            )

         return loss_term

@@ -972,10 +1043,12 @@ class ClipPPOLoss(PPOLoss):
             ``samples_mc_entropy`` will control how many samples will be used
             to compute this estimate.
             Defaults to ``1``.
-        entropy_coeff: (scalar | Mapping[NesstedKey, scalar], optional): entropy multiplier when computing the total loss.
+        entropy_coeff: (scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
             * **Scalar**: one value applied to the summed entropy of every action head.
             * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
             Defaults to ``0.01``.
+
+            See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
         critic_coeff (scalar, optional): critic loss multiplier when computing the total loss.
             Defaults to ``1.0``. Set ``critic_coeff`` to ``None`` to exclude the value loss
             from the forward outputs.
@@ -1269,6 +1342,8 @@ class KLPENPPOLoss(PPOLoss):
             * **Scalar**: one value applied to the summed entropy of every action head.
             * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
             Defaults to ``0.01``.
+
+            See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
         critic_coeff (scalar, optional): critic loss multiplier when computing the total loss.
             Defaults to ``1.0``.
         loss_critic_type (str, optional): loss function for the value discrepancy.
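As a sanity check on the per-head weighting added in `_weighted_loss_entropy` above, the short sketch below replays the "Mode 2" loop by hand on a toy entropy TensorDict. It is illustrative only: the `agent0`/`agent1` head names, the entropy values, and the coefficient mapping are assumptions taken from the docstring examples, not part of this patch.

# Minimal standalone sketch of the per-head entropy weighting (assumed head names and values).
import torch
from tensordict import TensorDict

entropy = TensorDict(
    {
        "agent0": {"action_log_prob": torch.tensor(1.20)},
        "agent1": {"action_log_prob": torch.tensor(1.14)},
    },
    batch_size=[],
)
entropy_coeff_map = {
    ("agent0", "action_log_prob"): 0.01,
    ("agent1", "action_log_prob"): 0.05,
}

loss_term = None
for head_name, entropy_head in entropy.items(include_nested=True, leaves_only=True):
    try:
        # Same lookup and failure mode as the loss module: every head needs a coefficient.
        coeff = entropy_coeff_map[head_name]
    except KeyError as exc:
        raise KeyError(f"Missing entropy coeff for head '{head_name}'") from exc
    coeff_t = torch.as_tensor(coeff, dtype=entropy_head.dtype, device=entropy_head.device)
    head_loss_term = -coeff_t * entropy_head  # negative entropy bonus for this head
    loss_term = head_loss_term if loss_term is None else loss_term + head_loss_term

print(loss_term)  # approx. -(0.01 * 1.20 + 0.05 * 1.14) = -0.069

With a scalar `entropy_coeff` instead of a mapping, the loss takes the "Mode 1" branch above: the per-head entropies are summed first and then scaled by the single coefficient.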