Skip to content

Commit 4db6eed

Browse files
author
Juan de los Rios
committed
fix typing
1 parent 0f0f654 commit 4db6eed

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

torchrl/objectives/ppo.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class PPOLoss(LossModule):
100100
``samples_mc_entropy`` will control how many
101101
samples will be used to compute this estimate.
102102
Defaults to ``1``.
103-
entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
103+
entropy_coeff: scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
104104
* **Scalar**: one value applied to the summed entropy of every action head.
105105
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
106106
Defaults to ``0.01``.
@@ -351,7 +351,7 @@ def __init__(
351351
*,
352352
entropy_bonus: bool = True,
353353
samples_mc_entropy: int = 1,
354-
entropy_coeff: float | NestedKey | None = None,
354+
entropy_coeff: float | Mapping[NestedKey, float] | None = None,
355355
log_explained_variance: bool = True,
356356
critic_coeff: float | None = None,
357357
loss_critic_type: str = "smooth_l1",
@@ -459,9 +459,7 @@ def __init__(
459459

460460
if isinstance(entropy_coeff, Mapping):
461461
# Store the mapping for per-head coefficients
462-
self._entropy_coeff_map = {
463-
str(k): float(v) for k, v in entropy_coeff.items()
464-
}
462+
self._entropy_coeff_map = {k: float(v) for k, v in entropy_coeff.items()}
465463
# Register an empty buffer for compatibility
466464
self.register_buffer("entropy_coeff", torch.tensor(0.0))
467465
elif isinstance(entropy_coeff, (float, int, torch.Tensor)):
@@ -974,7 +972,7 @@ class ClipPPOLoss(PPOLoss):
974972
``samples_mc_entropy`` will control how many
975973
samples will be used to compute this estimate.
976974
Defaults to ``1``.
977-
entropy_coeff: (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
975+
entropy_coeff: (scalar | Mapping[NesstedKey, scalar], optional): entropy multiplier when computing the total loss.
978976
* **Scalar**: one value applied to the summed entropy of every action head.
979977
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
980978
Defaults to ``0.01``.
@@ -1079,7 +1077,7 @@ def __init__(
10791077
clip_epsilon: float = 0.2,
10801078
entropy_bonus: bool = True,
10811079
samples_mc_entropy: int = 1,
1082-
entropy_coeff: float | Mapping[str | tuple | list, float] | None = None,
1080+
entropy_coeff: float | Mapping[NestedKey, float] | None = None,
10831081
critic_coeff: float | None = None,
10841082
loss_critic_type: str = "smooth_l1",
10851083
normalize_advantage: bool = False,
@@ -1267,7 +1265,7 @@ class KLPENPPOLoss(PPOLoss):
12671265
``samples_mc_entropy`` will control how many
12681266
samples will be used to compute this estimate.
12691267
Defaults to ``1``.
1270-
entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
1268+
entropy_coeff: scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
12711269
* **Scalar**: one value applied to the summed entropy of every action head.
12721270
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
12731271
Defaults to ``0.01``.
@@ -1373,7 +1371,7 @@ def __init__(
13731371
samples_mc_kl: int = 1,
13741372
entropy_bonus: bool = True,
13751373
samples_mc_entropy: int = 1,
1376-
entropy_coeff: float | Mapping[str | tuple | list, float] | None = None,
1374+
entropy_coeff: float | Mapping[NestedKey, float] | None = None,
13771375
critic_coeff: float | None = None,
13781376
loss_critic_type: str = "smooth_l1",
13791377
normalize_advantage: bool = False,

0 commit comments

Comments
 (0)