Skip to content

Commit 0f0f654

Browse files
author
Juan de los Rios
committed
use NestedKey for type
1 parent d2d65f4 commit 0f0f654

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

torchrl/objectives/ppo.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def __init__(
351351
*,
352352
entropy_bonus: bool = True,
353353
samples_mc_entropy: int = 1,
354-
entropy_coeff: float | Mapping[str | tuple | list, float] | None = None,
354+
entropy_coeff: float | NestedKey | None = None,
355355
log_explained_variance: bool = True,
356356
critic_coeff: float | None = None,
357357
loss_critic_type: str = "smooth_l1",
@@ -460,8 +460,7 @@ def __init__(
460460
if isinstance(entropy_coeff, Mapping):
461461
# Store the mapping for per-head coefficients
462462
self._entropy_coeff_map = {
463-
(tuple(k) if isinstance(k, list) else k): float(v)
464-
for k, v in entropy_coeff.items()
463+
str(k): float(v) for k, v in entropy_coeff.items()
465464
}
466465
# Register an empty buffer for compatibility
467466
self.register_buffer("entropy_coeff", torch.tensor(0.0))

0 commit comments

Comments
 (0)