@@ -129,7 +129,27 @@ def __init__(self,
129129 dropout = self .dropout ,)
130130
131131 critic_params = merge_dictionaries (critic_params , network_critic_params )
132-
132+ self .agent_params = {
133+ "mdp_info" : environment_info ,
134+ "actor_mu_params" : actor_mu_params ,
135+ "actor_sigma_params" : actor_sigma_params ,
136+ "actor_optimizer" : actor_optimizer ,
137+ "critic_params" : critic_params ,
138+ "batch_size" : batch_size ,
139+ "initial_replay_size" : initial_replay_size ,
140+ "max_replay_size" : max_replay_size ,
141+ "warmup_transitions" : warmup_transitions ,
142+ "tau" : tau ,
143+ "lr_alpha" : lr_alpha ,
144+ "use_log_alpha_loss" : use_log_alpha_loss ,
145+ "log_std_min" : log_std_min ,
146+ "log_std_max" : log_std_max ,
147+ "target_entropy" : target_entropy ,
148+ "critic_fit_params" : None
149+ }
150+ self ._obsprocessors = obsprocessors
151+ self .device = device
152+ self .agent_name = agent_name
133153 self .agent = SAC (
134154 mdp_info = environment_info ,
135155 actor_mu_params = actor_mu_params ,
@@ -228,6 +248,33 @@ def predict_(self, observation: np.ndarray) -> np.ndarray: #
228248 action = action .cpu ().detach ().numpy ()
229249
230250 return action
251+
252+ def update_task (self , env ):
253+ self .agent = SAC (
254+ mdp_info = env .mdp_info ,
255+ actor_mu_params = self .agent_params ["actor_mu_params" ],
256+ actor_sigma_params = self .agent_params ["actor_sigma_params" ],
257+ actor_optimizer = self .agent_params ["actor_optimizer" ],
258+ critic_params = self .agent_params ["critic_params" ],
259+ batch_size = self .agent_params ["batch_size" ],
260+ initial_replay_size = self .agent_params ["initial_replay_size" ],
261+ max_replay_size = self .agent_params ["max_replay_size" ],
262+ warmup_transitions = self .agent_params ["warmup_transitions" ],
263+ tau = self .agent_params ["tau" ],
264+ lr_alpha = self .agent_params ["lr_alpha" ],
265+ use_log_alpha_loss = self .agent_params ["use_log_alpha_loss" ],
266+ log_std_min = self .agent_params ["log_std_min" ],
267+ log_std_max = self .agent_params ["log_std_max" ],
268+ target_entropy = self .agent_params ["target_entropy" ],
269+ critic_fit_params = self .agent_params ["critic_fit_params" ]
270+ )
271+ self .obsprocessors = self ._obsprocessors
272+ super ().__init__ (
273+ environment_info = env .mdp_info ,
274+ obsprocessors = self ._obsprocessors ,
275+ device = self .device ,
276+ agent_name = self .agent_name
277+ )
231278
232279# %% ../../../nbs/30_agents/51_RL_agents/10_SAC_agents.ipynb 6
233280class SACAgent (SACBaseAgent ):
0 commit comments