@@ -2,21 +2,24 @@
 import numpy as np
 from mlagents.envs.brain import BrainInfo
 
+import tensorflow as tf
+
 from mlagents.trainers.buffer import Buffer
 from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
 from mlagents.trainers.components.reward_signals.curiosity.model import CuriosityModel
 from mlagents.trainers.tf_policy import TFPolicy
+from mlagents.trainers.models import LearningModel
 
 
 class CuriosityRewardSignal(RewardSignal):
     def __init__(
         self,
         policy: TFPolicy,
+        policy_model: LearningModel,
         strength: float,
         gamma: float,
         encoding_size: int = 128,
         learning_rate: float = 3e-4,
-        num_epoch: int = 3,
     ):
         """
         Creates the Curiosity reward generator
@@ -26,18 +29,20 @@ def __init__(
         :param gamma: The time discounting factor used for this reward.
         :param encoding_size: The size of the hidden encoding layer for the ICM
         :param learning_rate: The learning rate for the ICM.
-        :param num_epoch: The number of epochs to train over the training buffer for the ICM.
         """
-        super().__init__(policy, strength, gamma)
+        super().__init__(policy, policy_model, strength, gamma)
         self.model = CuriosityModel(
-            policy.model, encoding_size=encoding_size, learning_rate=learning_rate
+            policy_model, encoding_size=encoding_size, learning_rate=learning_rate
         )
-        self.num_epoch = num_epoch
         self.use_terminal_states = False
         self.update_dict = {
-            "forward_loss": self.model.forward_loss,
-            "inverse_loss": self.model.inverse_loss,
-            "update": self.model.update_batch,
+            "curiosity_forward_loss": self.model.forward_loss,
+            "curiosity_inverse_loss": self.model.inverse_loss,
+            "curiosity_update": self.model.update_batch,
+        }
+        self.stats_name_to_update_name = {
+            "Losses/Curiosity Forward Loss": "curiosity_forward_loss",
+            "Losses/Curiosity Inverse Loss": "curiosity_inverse_loss",
         }
         self.has_updated = False
 
@@ -89,67 +94,39 @@ def check_config(
         param_keys = ["strength", "gamma", "encoding_size"]
         super().check_config(config_dict, param_keys)
 
-    def update(self, update_buffer: Buffer, num_sequences: int) -> Dict[str, float]:
-        """
-        Updates Curiosity model using training buffer. Divides training buffer into mini batches and performs
-        gradient descent.
-        :param update_buffer: Update buffer from which to pull data from.
-        :param num_sequences: Number of sequences in the update buffer.
-        :return: Dict of stats that should be reported to Tensorboard.
-        """
-        forward_total: List[float] = []
-        inverse_total: List[float] = []
-        for _ in range(self.num_epoch):
-            update_buffer.shuffle(sequence_length=self.policy.sequence_length)
-            buffer = update_buffer
-            for l in range(len(update_buffer["actions"]) // num_sequences):
-                start = l * num_sequences
-                end = (l + 1) * num_sequences
-                run_out_curio = self._update_batch(
-                    buffer.make_mini_batch(start, end), num_sequences
-                )
-                inverse_total.append(run_out_curio["inverse_loss"])
-                forward_total.append(run_out_curio["forward_loss"])
-
-        update_stats = {
-            "Losses/Curiosity Forward Loss": np.mean(forward_total),
-            "Losses/Curiosity Inverse Loss": np.mean(inverse_total),
-        }
-        return update_stats
-
-    def _update_batch(
-        self, mini_batch: Dict[str, np.ndarray], num_sequences: int
-    ) -> Dict[str, float]:
+    def prepare_update(
+        self,
+        policy_model: LearningModel,
+        mini_batch: Dict[str, np.ndarray],
+        num_sequences: int,
+    ) -> Dict[tf.Tensor, Any]:
         """
-        Updates model using buffer.
+        Prepare for update and get feed_dict.
         :param num_sequences: Number of trajectories in batch.
         :param mini_batch: Experience batch.
-        :return: Output from update process.
+        :return: Feed_dict needed for update.
         """
         feed_dict = {
-            self.policy.model.batch_size: num_sequences,
-            self.policy.model.sequence_length: self.policy.sequence_length,
-            self.policy.model.mask_input: mini_batch["masks"],
-            self.policy.model.advantage: mini_batch["advantages"],
-            self.policy.model.all_old_log_probs: mini_batch["action_probs"],
+            policy_model.batch_size: num_sequences,
+            policy_model.sequence_length: self.policy.sequence_length,
+            policy_model.mask_input: mini_batch["masks"],
+            policy_model.advantage: mini_batch["advantages"],
+            policy_model.all_old_log_probs: mini_batch["action_probs"],
         }
         if self.policy.use_continuous_act:
-            feed_dict[self.policy.model.output_pre] = mini_batch["actions_pre"]
+            feed_dict[policy_model.output_pre] = mini_batch["actions_pre"]
         else:
-            feed_dict[self.policy.model.action_holder] = mini_batch["actions"]
+            feed_dict[policy_model.action_holder] = mini_batch["actions"]
         if self.policy.use_vec_obs:
-            feed_dict[self.policy.model.vector_in] = mini_batch["vector_obs"]
+            feed_dict[policy_model.vector_in] = mini_batch["vector_obs"]
             feed_dict[self.model.next_vector_in] = mini_batch["next_vector_in"]
-        if self.policy.model.vis_obs_size > 0:
-            for i, _ in enumerate(self.policy.model.visual_in):
-                feed_dict[self.policy.model.visual_in[i]] = mini_batch[
-                    "visual_obs%d" % i
-                ]
-            for i, _ in enumerate(self.policy.model.visual_in):
+        if policy_model.vis_obs_size > 0:
+            for i, _ in enumerate(policy_model.visual_in):
+                feed_dict[policy_model.visual_in[i]] = mini_batch["visual_obs%d" % i]
+            for i, _ in enumerate(policy_model.visual_in):
                 feed_dict[self.model.next_visual_in[i]] = mini_batch[
                     "next_visual_obs%d" % i
                 ]
 
         self.has_updated = True
-        run_out = self.policy._execute_model(feed_dict, self.update_dict)
-        return run_out
+        return feed_dict
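
With this change the curiosity signal no longer drives its own training loop: prepare_update only builds a feed_dict, and the ops in update_dict are expected to be batched into the trainer's single session run, with stats_name_to_update_name mapping the fetched values back to Tensorboard stat names. Below is a minimal sketch of such a caller; only prepare_update, update_dict, stats_name_to_update_name, and TFPolicy._execute_model come from this file, while the helper name and the reward_signals container are illustrative assumptions, not part of this diff.

# Hypothetical trainer-side helper (not part of this diff) showing how the
# pieces introduced above could be combined into one update pass.
def update_reward_signals(policy, policy_model, reward_signals, mini_batch, num_sequences):
    feed_dict = {}      # tf placeholder -> value, merged across all reward signals
    update_dict = {}    # fetch name -> loss tensor / update op
    stats_needed = {}   # Tensorboard stat name -> key into update_dict
    for signal in reward_signals.values():
        # Each signal contributes its placeholders and the ops it wants run.
        feed_dict.update(signal.prepare_update(policy_model, mini_batch, num_sequences))
        update_dict.update(signal.update_dict)
        stats_needed.update(signal.stats_name_to_update_name)
    # A single _execute_model call runs every signal's losses and update ops together.
    run_out = policy._execute_model(feed_dict, update_dict)
    # Map the raw fetch results back to the reported stat names.
    return {stat: run_out[key] for stat, key in stats_needed.items()}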