Skip to content

Commit 6909a47

Browse files
committed
Fixes weight schedule dict for RND
1 parent 8818338 commit 6909a47

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

config/dummy_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ algorithm:
2323

2424
# note: This is a dictionary with a required key called "mode" which can be one of "constant" or "step".
2525
# - If "constant", then the weight is constant.
26-
# - If "step", then the weight is updated using the step scheduler. It takes additional parameters:
26+
# - If "step", then the weight is updated using the step scheduler. The dictionary should contain additional parameters:
2727
# - max_num_steps: maximum number of steps to update the weight
2828
# - final_value: final value of the weight
2929
# If None, then no scheduler is used.

rsl_rl/modules/rnd.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(
3030
state_normalization: bool = False,
3131
reward_normalization: bool = False,
3232
device: str = "cpu",
33-
weight_schedule: str | None = None,
33+
weight_schedule: dict | None = None,
3434
**kwargs,
3535
):
3636
"""Initialize the RND module.
@@ -53,8 +53,13 @@ def __init__(
5353
state_normalization: Whether to normalize the input state. Defaults to False.
5454
reward_normalization: Whether to normalize the intrinsic reward. Defaults to False.
5555
device: Device to use. Defaults to "cpu".
56-
weight_schedule: The type of schedule to use for the RND weight parameter. Must be one of ["constant", "step"].
56+
weight_schedule: The configuration of the schedule to use for the RND weight parameter.
5757
Defaults to None, in which case the weight parameter is constant.
58+
It is a dictionary with the following keys:
59+
60+
- "mode": The type of schedule to use for the RND weight parameter. Must be one of "constant" or "step".
61+
- "max_num_steps": Maximum number of steps to update the weight. Used for the weight schedule of type "step".
62+
- "final_value": Final value of the weight parameter. Used for the weight schedule of type "step".
5863
5964
Keyword Args:
6065
@@ -111,7 +116,7 @@ def get_intrinsic_reward(self, gated_state) -> tuple[torch.Tensor, torch.Tensor]
111116

112117
# Check the weight schedule
113118
if self.weight_scheduler is not None:
114-
self.weight = self.weight_scheduler(self.update_counter, **self.weight_scheduler_params)
119+
self.weight = self.weight_scheduler(step=self.update_counter, **self.weight_scheduler_params)
115120
else:
116121
self.weight = self.initial_weight
117122
# Scale intrinsic reward

0 commit comments

Comments
 (0)