@@ -30,7 +30,7 @@ def __init__(
         state_normalization: bool = False,
         reward_normalization: bool = False,
         device: str = "cpu",
-        weight_schedule: str | None = None,
+        weight_schedule: dict | None = None,
         **kwargs,
     ):
         """Initialize the RND module.
@@ -53,8 +53,13 @@ def __init__(
             state_normalization: Whether to normalize the input state. Defaults to False.
             reward_normalization: Whether to normalize the intrinsic reward. Defaults to False.
             device: Device to use. Defaults to "cpu".
-            weight_schedule: The type of schedule to use for the RND weight parameter. Must be one of ["constant", "step"].
+            weight_schedule: Configuration of the schedule for the RND weight parameter.
                 Defaults to None, in which case the weight parameter is constant.
+                Otherwise, it is a dictionary with the following keys:
+
+                - "mode": The type of schedule to use for the RND weight parameter.
+                - "max_num_steps": Maximum number of steps per episode. Used for the weight schedule of type "step".
+                - "final_value": Final value of the weight parameter. Used for the weight schedule of type "step".

         Keyword Args:

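For readers skimming the docstring change, a minimal sketch of how the new dict-valued `weight_schedule` argument might be constructed; only the keys come from the docstring above, while the class name, import path, and concrete values are assumptions for illustration.

```python
# Sketch only: the class name / import path are assumed, not taken from this diff.
# from rnd_module import RND

# After this change, weight_schedule is a dict instead of a plain string.
weight_schedule = {
    "mode": "step",          # type of schedule for the RND weight
    "max_num_steps": 1000,   # maximum number of steps per episode (illustrative value)
    "final_value": 0.0,      # weight value the schedule moves toward (illustrative value)
}

# rnd = RND(device="cpu", weight_schedule=weight_schedule)
```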
@@ -111,7 +116,7 @@ def get_intrinsic_reward(self, gated_state) -> tuple[torch.Tensor, torch.Tensor]

         # Check the weight schedule
         if self.weight_scheduler is not None:
-            self.weight = self.weight_scheduler(self.update_counter, **self.weight_scheduler_params)
+            self.weight = self.weight_scheduler(step=self.update_counter, **self.weight_scheduler_params)
         else:
             self.weight = self.initial_weight
         # Scale intrinsic reward
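The only behavioral change in this hunk is that the current step is now passed by keyword. Below is a minimal sketch of a scheduler callable compatible with the new call; the function name, parameter names, and the annealing rule are assumptions for illustration, not the repository's actual scheduler.

```python
# Hypothetical scheduler compatible with
#     self.weight = self.weight_scheduler(step=self.update_counter, **self.weight_scheduler_params)
# Name, parameters, and annealing rule are illustrative assumptions, not this repo's code.
def step_weight_schedule(step: int, initial_value: float, final_value: float, max_num_steps: int) -> float:
    """Anneal the RND weight from initial_value to final_value over max_num_steps steps."""
    fraction = min(step / max_num_steps, 1.0)
    return initial_value + fraction * (final_value - initial_value)


# Mirrors the updated call site, with weight_scheduler_params supplying everything but `step`:
weight_scheduler_params = {"initial_value": 1.0, "final_value": 0.0, "max_num_steps": 1000}
weight = step_weight_schedule(step=250, **weight_scheduler_params)  # -> 0.75
```

Passing `step` as a keyword keeps the scheduler's signature explicit and leaves `weight_scheduler_params` free to carry only the schedule-specific settings (e.g. the dict keys documented above).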