Commit cd3602d

Fixes unexpected keyword argument learning_rate within RandomNetworkDistillation (#87)
* Removes learning_rate from the dict argument of RandomNetworkDistillation
* Updates contributors
* Uses pop to remove learning_rate rather than using a dict copy
1 parent c6834de commit cd3602d
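
For context, here is a minimal, self-contained sketch of the failure mode and of the pop-before-unpack pattern this commit applies. The RandomNetworkDistillation stub below uses an assumed, simplified signature (device plus a hypothetical num_states argument) and is not the real rsl_rl class.

# Stand-in class with an assumed signature: like the real module, it accepts
# no learning_rate keyword, which is why the original call failed.
class RandomNetworkDistillation:
    def __init__(self, device, num_states=8):
        self.device = device
        self.num_states = num_states

rnd_cfg = {"num_states": 8, "learning_rate": 1e-3}

# Before the fix: learning_rate stays in the dict, so unpacking it raises
#   TypeError: __init__() got an unexpected keyword argument 'learning_rate'
# rnd = RandomNetworkDistillation(device="cpu", **rnd_cfg)

# After the fix: pop the optimizer-only key first, then unpack the rest.
learning_rate = rnd_cfg.pop("learning_rate", 1e-3)
rnd = RandomNetworkDistillation(device="cpu", **rnd_cfg)
print(learning_rate)  # 0.001, later handed to the Adam optimizer as lr

Using pop removes the key from rnd_cfg in place before the dict is unpacked, so no copy of the config is needed and the optimizer still receives its learning rate.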

File tree

2 files changed: +4 lines, -1 line


CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ Please keep the lists sorted alphabetically.
 * Lorenzo Terenzi
 * Marko Bjelonic
 * Matthijs van der Boon
+* Özhan Özen
 * Pascal Roth
 * Zhang Chong
 * Ziqi Fan

rsl_rl/algorithms/ppo.py

Lines changed: 3 additions & 1 deletion
@@ -59,11 +59,13 @@ def __init__(
 
         # RND components
         if rnd_cfg is not None:
+            # Extract learning rate and remove it from the original dict
+            learning_rate = rnd_cfg.pop("learning_rate", 1e-3)
             # Create RND module
             self.rnd = RandomNetworkDistillation(device=self.device, **rnd_cfg)
             # Create RND optimizer
             params = self.rnd.predictor.parameters()
-            self.rnd_optimizer = optim.Adam(params, lr=rnd_cfg.get("learning_rate", 1e-3))
+            self.rnd_optimizer = optim.Adam(params, lr=learning_rate)
         else:
             self.rnd = None
             self.rnd_optimizer = None
