@@ -31,7 +31,6 @@ def __init__(
         reward_normalization: bool = False,
         device: str = "cpu",
         weight_schedule: dict | None = None,
-        **kwargs,
     ):
         """Initialize the RND module.

@@ -58,13 +57,19 @@ def __init__(
             It is a dictionary with the following keys:

             - "mode": The type of schedule to use for the RND weight parameter.
-            - "max_num_steps": Maximum number of steps per episode. Used for the weight schedule of type "step".
-            - "final_value": Final value of the weight parameter. Used for the weight schedule of type "step".
+                - "constant": Constant weight schedule.
+                - "step": Step weight schedule.
+                - "linear": Linear weight schedule.

-        Keyword Args:
+            For the "step" weight schedule, the following parameters are required:

-            max_num_steps (int): Maximum number of steps per episode. Used for the weight schedule of type "step".
-            final_value (float): Final value of the weight parameter. Used for the weight schedule of type "step".
+            - "final_step": The step at which the weight parameter is set to the final value.
+            - "final_value": The final value of the weight parameter.
+
+            For the "linear" weight schedule, the following parameters are required:
+            - "initial_step": The step at which the weight parameter is set to the initial value.
+            - "final_step": The step at which the weight parameter is set to the final value.
+            - "final_value": The final value of the weight parameter.
         """
         # initialize parent class
         super().__init__()
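For reference, a minimal sketch of what the `weight_schedule` dictionary described in the hunk above could look like for each mode. The key names ("mode", "final_step", "final_value", "initial_step") come from the docstring in this commit; the concrete values are illustrative placeholders, not taken from the repository.

```python
# Illustrative weight_schedule configurations (values are placeholders, not defaults).
constant_schedule = {"mode": "constant"}

step_schedule = {
    "mode": "step",
    "final_step": 1_000,   # step at which the weight switches to final_value
    "final_value": 0.0,    # weight used from final_step onwards
}

linear_schedule = {
    "mode": "linear",
    "initial_step": 500,   # the initial weight is kept until this step
    "final_step": 2_000,   # the final value is reached at this step
    "final_value": 0.0,    # weight reached at (and held after) final_step
}
```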
@@ -79,7 +84,7 @@ def __init__(

         # Normalization of input gates
         if state_normalization:
-            self.state_normalizer = EmpiricalNormalization(shape=[self.num_obs], until=1.0e8).to(self.device)
+            self.state_normalizer = EmpiricalNormalization(shape=[self.num_states], until=1.0e8).to(self.device)
         else:
             self.state_normalizer = torch.nn.Identity()
         # Normalization of intrinsic reward
@@ -101,14 +106,14 @@ def __init__(
         self.predictor = self._build_mlp(num_states, predictor_hidden_dims, num_outputs, activation).to(self.device)
         self.target = self._build_mlp(num_states, target_hidden_dims, num_outputs, activation).to(self.device)

-    def get_intrinsic_reward(self, gated_state) -> tuple[torch.Tensor, torch.Tensor]:
+    def get_intrinsic_reward(self, rnd_state) -> tuple[torch.Tensor, torch.Tensor]:
         # note: the counter is updated number of env steps per learning iteration
         self.update_counter += 1
-        # Normalize gated state
-        gated_state = self.state_normalizer(gated_state)
-        # Obtain the embedding of the gated state from the target and predictor networks
-        target_embedding = self.target(gated_state).detach()
-        predictor_embedding = self.predictor(gated_state).detach()
+        # Normalize rnd state
+        rnd_state = self.state_normalizer(rnd_state)
+        # Obtain the embedding of the rnd state from the target and predictor networks
+        target_embedding = self.target(rnd_state).detach()
+        predictor_embedding = self.predictor(rnd_state).detach()
         # Compute the intrinsic reward as the distance between the embeddings
         intrinsic_reward = torch.linalg.norm(target_embedding - predictor_embedding, dim=1)
         # Normalize intrinsic reward
@@ -122,7 +127,7 @@ def get_intrinsic_reward(self, gated_state) -> tuple[torch.Tensor, torch.Tensor]
         # Scale intrinsic reward
         intrinsic_reward *= self.weight

-        return intrinsic_reward, gated_state
+        return intrinsic_reward, rnd_state

     def forward(self, *args, **kwargs):
         raise RuntimeError("Forward method is not implemented. Use get_intrinsic_reward instead.")
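To make the renamed interface concrete, here is a self-contained sketch of the computation `get_intrinsic_reward` performs on the `rnd_state` input: the intrinsic reward is the per-sample distance between the embeddings of a fixed target network and a trained predictor network. The network sizes and batch shape below are illustrative assumptions, not values from this commit.

```python
import torch

num_envs, num_states, num_outputs = 4, 16, 8  # illustrative sizes

# Fixed (target) and trainable (predictor) embedding networks.
target = torch.nn.Sequential(torch.nn.Linear(num_states, 64), torch.nn.ELU(), torch.nn.Linear(64, num_outputs))
predictor = torch.nn.Sequential(torch.nn.Linear(num_states, 64), torch.nn.ELU(), torch.nn.Linear(64, num_outputs))

rnd_state = torch.randn(num_envs, num_states)  # one RND state per environment
with torch.no_grad():
    # Distance between the two embeddings, one reward per environment.
    intrinsic_reward = torch.linalg.norm(target(rnd_state) - predictor(rnd_state), dim=1)
print(intrinsic_reward.shape)  # torch.Size([4])
```

In the module itself, this value is additionally normalized (if reward normalization is enabled) and scaled by the scheduled weight before being returned together with the normalized `rnd_state`.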
@@ -171,8 +176,16 @@ def _build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activa
     Different weight schedules.
     """

-    def _constant_weight_schedule(self, step, **kwargs):
+    def _constant_weight_schedule(self, step: int, **kwargs):
         return self.initial_weight

-    def _step_weight_schedule(self, step, max_num_steps: int, final_value: float, **kwargs):
-        return self.initial_weight if step < max_num_steps else final_value
+    def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs):
+        return self.initial_weight if step < final_step else final_value
+
+    def _linear_weight_schedule(self, step: int, initial_step: int, final_step: int, final_value: float, **kwargs):
+        if step < initial_step:
+            return self.initial_weight
+        elif step > final_step:
+            return final_value
+        else:
+            return self.initial_weight + (final_value - self.initial_weight) * (step - initial_step) / (final_step - initial_step)
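As a quick sanity check on the new linear schedule, a standalone version of the same interpolation can be evaluated at a few steps. The initial weight and the step values below are illustrative; in the module the initial weight comes from `self.initial_weight`.

```python
def linear_weight(step: int, initial_weight: float, initial_step: int, final_step: int, final_value: float) -> float:
    # Same logic as _linear_weight_schedule: hold, ramp linearly, then hold the final value.
    if step < initial_step:
        return initial_weight
    elif step > final_step:
        return final_value
    return initial_weight + (final_value - initial_weight) * (step - initial_step) / (final_step - initial_step)

# With initial_weight=1.0, initial_step=500, final_step=2000, final_value=0.0:
print(linear_weight(0, 1.0, 500, 2000, 0.0))     # 1.0 (before the ramp starts)
print(linear_weight(1250, 1.0, 500, 2000, 0.0))  # 0.5 (halfway through the ramp)
print(linear_weight(3000, 1.0, 500, 2000, 0.0))  # 0.0 (after the ramp ends)
```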