diff --git a/test/test_cost.py b/test/test_cost.py
index 308a51abf5e..c159366221b 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -9659,6 +9659,10 @@ def test_ppo_composite_dists(self):
 
         make_params = TensorDictModule(
             lambda: (
+                torch.ones(4),
+                torch.ones(4),
+                torch.ones(4),
+                torch.ones(4),
                 torch.ones(4),
                 torch.ones(4),
                 torch.ones(4, 2),
@@ -9669,8 +9673,12 @@ def test_ppo_composite_dists(self):
             ),
             in_keys=[],
             out_keys=[
-                ("params", "gamma", "concentration"),
-                ("params", "gamma", "rate"),
+                ("params", "gamma1", "concentration"),
+                ("params", "gamma1", "rate"),
+                ("params", "gamma2", "concentration"),
+                ("params", "gamma2", "rate"),
+                ("params", "gamma3", "concentration"),
+                ("params", "gamma3", "rate"),
                 ("params", "Kumaraswamy", "concentration0"),
                 ("params", "Kumaraswamy", "concentration1"),
                 ("params", "mixture", "logits"),
@@ -9687,14 +9695,18 @@ def mixture_constructor(logits, loc, scale):
         dist_constructor = functools.partial(
             CompositeDistribution,
             distribution_map={
-                "gamma": d.Gamma,
+                "gamma1": d.Gamma,
+                "gamma2": d.Gamma,
+                "gamma3": d.Gamma,
                 "Kumaraswamy": d.Kumaraswamy,
                 "mixture": mixture_constructor,
             },
             name_map={
-                "gamma": ("agent0", "action"),
+                "gamma1": ("agent0", "action", "action1", "sub_action1"),
+                "gamma2": ("agent0", "action", "action1", "sub_action2"),
+                "gamma3": ("agent0", "action", "action2"),
                 "Kumaraswamy": ("agent1", "action"),
-                "mixture": ("agent2", "action"),
+                "mixture": "agent2",
             },
         )
         policy = ProbSeq(
@@ -9702,9 +9714,11 @@ def mixture_constructor(logits, loc, scale):
             ProbabilisticTensorDictModule(
                 in_keys=["params"],
                 out_keys=[
-                    ("agent0", "action"),
+                    ("agent0", "action", "action1", "sub_action1"),
+                    ("agent0", "action", "action1", "sub_action2"),
+                    ("agent0", "action", "action2"),
                     ("agent1", "action"),
-                    ("agent2", "action"),
+                    "agent2",
                 ],
                 distribution_class=dist_constructor,
                 return_log_prob=True,
@@ -9739,14 +9753,18 @@ def mixture_constructor(logits, loc, scale):
         ppo = cls(policy, value_operator, entropy_coeff=scalar_entropy)
         ppo.set_keys(
             action=[
-                ("agent0", "action"),
+                ("agent0", "action", "action1", "sub_action1"),
+                ("agent0", "action", "action1", "sub_action2"),
+                ("agent0", "action", "action2"),
                 ("agent1", "action"),
-                ("agent2", "action"),
+                "agent2",
             ],
             sample_log_prob=[
-                ("agent0", "action_log_prob"),
+                ("agent0", "action", "action1", "sub_action1_log_prob"),
+                ("agent0", "action", "action1", "sub_action2_log_prob"),
+                ("agent0", "action", "action2_log_prob"),
                 ("agent1", "action_log_prob"),
-                ("agent2", "action_log_prob"),
+                "agent2_log_prob",
             ],
         )
         loss = ppo(data)
@@ -9761,21 +9779,27 @@ def mixture_constructor(logits, loc, scale):
         # keep per-head entropies instead of the aggregated tensor
         set_composite_lp_aggregate(False).set()
         coef_map = {
-            "agent0": 0.10,
-            "agent1": 0.05,
-            "agent2": 0.02,
+            ("agent0", "action", "action1", "sub_action1_log_prob"): 0.02,
+            ("agent0", "action", "action1", "sub_action2_log_prob"): 0.01,
+            ("agent0", "action", "action2_log_prob"): 0.01,
+            ("agent1", "action_log_prob"): 0.01,
+            "agent2_log_prob": 0.01,
         }
         ppo_weighted = cls(policy, value_operator, entropy_coeff=coef_map)
         ppo_weighted.set_keys(
             action=[
-                ("agent0", "action"),
+                ("agent0", "action", "action1", "sub_action1"),
+                ("agent0", "action", "action1", "sub_action2"),
+                ("agent0", "action", "action2"),
                 ("agent1", "action"),
-                ("agent2", "action"),
+                "agent2",
             ],
             sample_log_prob=[
-                ("agent0", "action_log_prob"),
+                ("agent0", "action", "action1", "sub_action1_log_prob"),
+                ("agent0", "action", "action1", "sub_action2_log_prob"),
+                ("agent0", "action", "action2_log_prob"),
                 ("agent1", "action_log_prob"),
-                ("agent2", "action_log_prob"),
+                "agent2_log_prob",
             ],
         )
         loss = ppo_weighted(data)
@@ -9786,9 +9810,11 @@ def mixture_constructor(logits, loc, scale):
         assert torch.isfinite(loss["loss_entropy"])
         # Check individual loss is computed with the right weights
         expected_loss = 0.0
-        for name, head_entropy in composite_entropy.items():
+        for name, head_entropy in composite_entropy.items(
+            include_nested=True, leaves_only=True
+        ):
             expected_loss -= (
-                coef_map[name] * _sum_td_features(head_entropy)
+                coef_map[name] * head_entropy
             ).mean()
         torch.testing.assert_close(
             loss["loss_entropy"], expected_loss, rtol=1e-5, atol=1e-7
@@ -9846,7 +9872,7 @@ def test_weighted_entropy_scalar(self):
         torch.testing.assert_close(out, torch.tensor(-1.0))
 
     def test_weighted_entropy_mapping(self):
-        coef = {"head_0": 0.3, "head_1": 0.7}
+        coef = {("head_0", "action_log_prob"): 0.3, ("head_1", "action_log_prob"): 0.7}
         loss = self._make_entropy_loss(entropy_coeff=coef)
         entropy = TensorDict(
             {
@@ -9856,7 +9882,10 @@ def test_weighted_entropy_mapping(self):
             [],
         )
         out = loss._weighted_loss_entropy(entropy)
-        expected = -(coef["head_0"] * 1.0 + coef["head_1"] * 2.0)
+        expected = -(
+            coef[("head_0", "action_log_prob")] * 1.0
+            + coef[("head_1", "action_log_prob")] * 2.0
+        )
         torch.testing.assert_close(out, torch.tensor(expected))
 
     def test_weighted_entropy_mapping_missing_key(self):
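Note on the key convention the updated test relies on (illustration, not part of the patch): with `set_composite_lp_aggregate(False)`, `CompositeDistribution` writes one `<leaf>_log_prob` entry next to each sampled leaf, so the full nested key is what `set_keys(sample_log_prob=...)` and the `entropy_coeff` mapping must name. A minimal, self-contained sketch, with head names borrowed from the test and the import locations assumed from current tensordict:

```python
import torch
import torch.distributions as d
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate

# Keep per-head log-probs instead of one aggregated tensor.
set_composite_lp_aggregate(False).set()

params = TensorDict(
    {
        "gamma1": {"concentration": torch.ones(4), "rate": torch.ones(4)},
        "gamma3": {"concentration": torch.ones(4), "rate": torch.ones(4)},
    },
    batch_size=[],
)
dist = CompositeDistribution(
    params,
    distribution_map={"gamma1": d.Gamma, "gamma3": d.Gamma},
    # name_map routes each head's sample to a nested destination key.
    name_map={
        "gamma1": ("agent0", "action", "action1", "sub_action1"),
        "gamma3": ("agent0", "action", "action2"),
    },
)
sample = dist.sample()
log_prob = dist.log_prob(sample)
# Expected leaves include ("agent0", "action", "action1", "sub_action1_log_prob")
# and ("agent0", "action", "action2_log_prob") -- the keys the coef_map uses.
print(list(log_prob.keys(include_nested=True, leaves_only=True)))
```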
"sub_action2_log_prob"), + ("agent0", "action", "action2_log_prob"), ("agent1", "action_log_prob"), - ("agent2", "action_log_prob"), + ("agent2_log_prob"), ], ) loss = ppo_weighted(data) @@ -9786,9 +9810,11 @@ def mixture_constructor(logits, loc, scale): assert torch.isfinite(loss["loss_entropy"]) # Check individual loss is computed with the right weights expected_loss = 0.0 - for name, head_entropy in composite_entropy.items(): + for i, (_, head_entropy) in enumerate( + composite_entropy.items(include_nested=True, leaves_only=True) + ): expected_loss -= ( - coef_map[name] * _sum_td_features(head_entropy) + coef_map[list(coef_map.keys())[i]] * head_entropy ).mean() torch.testing.assert_close( loss["loss_entropy"], expected_loss, rtol=1e-5, atol=1e-7 @@ -9846,7 +9872,7 @@ def test_weighted_entropy_scalar(self): torch.testing.assert_close(out, torch.tensor(-1.0)) def test_weighted_entropy_mapping(self): - coef = {"head_0": 0.3, "head_1": 0.7} + coef = {("head_0", "action_log_prob"): 0.3, ("head_1", "action_log_prob"): 0.7} loss = self._make_entropy_loss(entropy_coeff=coef) entropy = TensorDict( { @@ -9856,7 +9882,10 @@ def test_weighted_entropy_mapping(self): [], ) out = loss._weighted_loss_entropy(entropy) - expected = -(coef["head_0"] * 1.0 + coef["head_1"] * 2.0) + expected = -( + coef[("head_0", "action_log_prob")] * 1.0 + + coef[("head_1", "action_log_prob")] * 2.0 + ) torch.testing.assert_close(out, torch.tensor(expected)) def test_weighted_entropy_mapping_missing_key(self): diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index e9e126dc282..23fb856a413 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -100,7 +100,7 @@ class PPOLoss(LossModule): ``samples_mc_entropy`` will control how many samples will be used to compute this estimate. Defaults to ``1``. - entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss. + entropy_coeff: scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss. * **Scalar**: one value applied to the summed entropy of every action head. * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy. Defaults to ``0.01``. @@ -351,7 +351,7 @@ def __init__( *, entropy_bonus: bool = True, samples_mc_entropy: int = 1, - entropy_coeff: float | Mapping[str, float] | None = None, + entropy_coeff: float | Mapping[NestedKey, float] | None = None, log_explained_variance: bool = True, critic_coeff: float | None = None, loss_critic_type: str = "smooth_l1", @@ -459,9 +459,7 @@ def __init__( if isinstance(entropy_coeff, Mapping): # Store the mapping for per-head coefficients - self._entropy_coeff_map = { - str(k): float(v) for k, v in entropy_coeff.items() - } + self._entropy_coeff_map = {k: float(v) for k, v in entropy_coeff.items()} # Register an empty buffer for compatibility self.register_buffer("entropy_coeff", torch.tensor(0.0)) elif isinstance(entropy_coeff, (float, int, torch.Tensor)): @@ -911,6 +909,7 @@ def _weighted_loss_entropy( If `self._entropy_coeff_map` is provided, apply per-head entropy coefficients. Otherwise, use the scalar `self.entropy_coeff`. + The entries in self._entropy_coeff_map require the full nested key to the entropy head. 
""" if self._entropy_coeff_map is None: if is_tensor_collection(entropy): @@ -918,7 +917,10 @@ def _weighted_loss_entropy( return -self.entropy_coeff * entropy loss_term = None # running sum over heads - for head_name, entropy_head in entropy.items(): + coeff = 0 + for head_name, entropy_head in entropy.items( + include_nested=True, leaves_only=True + ): try: coeff = self._entropy_coeff_map[head_name] except KeyError as exc: @@ -926,7 +928,7 @@ def _weighted_loss_entropy( coeff_t = torch.as_tensor( coeff, dtype=entropy_head.dtype, device=entropy_head.device ) - head_loss_term = -coeff_t * _sum_td_features(entropy_head) + head_loss_term = -coeff_t * entropy_head loss_term = ( head_loss_term if loss_term is None else loss_term + head_loss_term ) # accumulate @@ -970,7 +972,7 @@ class ClipPPOLoss(PPOLoss): ``samples_mc_entropy`` will control how many samples will be used to compute this estimate. Defaults to ``1``. - entropy_coeff: (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss. + entropy_coeff: (scalar | Mapping[NesstedKey, scalar], optional): entropy multiplier when computing the total loss. * **Scalar**: one value applied to the summed entropy of every action head. * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy. Defaults to ``0.01``. @@ -1075,7 +1077,7 @@ def __init__( clip_epsilon: float = 0.2, entropy_bonus: bool = True, samples_mc_entropy: int = 1, - entropy_coeff: float | Mapping[str, float] | None = None, + entropy_coeff: float | Mapping[NestedKey, float] | None = None, critic_coeff: float | None = None, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, @@ -1263,7 +1265,7 @@ class KLPENPPOLoss(PPOLoss): ``samples_mc_entropy`` will control how many samples will be used to compute this estimate. Defaults to ``1``. - entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss. + entropy_coeff: scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss. * **Scalar**: one value applied to the summed entropy of every action head. * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy. Defaults to ``0.01``. @@ -1369,7 +1371,7 @@ def __init__( samples_mc_kl: int = 1, entropy_bonus: bool = True, samples_mc_entropy: int = 1, - entropy_coeff: float | Mapping[str, float] | None = None, + entropy_coeff: float | Mapping[NestedKey, float] | None = None, critic_coeff: float | None = None, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False,