4444 auto_torch_device ,
4545 init_logging ,
4646)
47-
47+ from tqdm import tqdm
4848
4949def ensure_primitive (maybe_tensor ):
5050 if isinstance (maybe_tensor , np .ndarray ):
@@ -152,7 +152,7 @@ def main(cfg: TrainPipelineConfig):
152152 ds_advantage = {} # per-dataset advantages
153153 with torch .inference_mode ():
154154 # First pass to get the values
155- for batch in dataloader :
155+ for batch in tqdm ( dataloader , desc = "Computing values" ) :
156156 for key , value in batch .items ():
157157 if isinstance (value , torch .Tensor ):
158158 batch [key ] = value .to (device )
@@ -172,15 +172,15 @@ def main(cfg: TrainPipelineConfig):
172172 success = success ,
173173 n_steps_look_ahead = cfg .policy .reward_config .N_steps_look_ahead ,
174174 episode_end_idx = episode_end_idx ,
175- max_episode_length = cfg .policy .reward_config .reward_normalizer ,
175+ reward_normalizer = cfg .policy .reward_config .reward_normalizer ,
176176 current_idx = current_idx ,
177177 c_neg = cfg .policy .reward_config .C_neg ,
178178 )
179179
180- values [(episode_index , current_idx )] = {"v0" : v0 , "reward" : reward }
180+ values [(episode_index . item () , current_idx . item () )] = {"v0" : v0 , "reward" : reward }
181181
182182 # Second pass to compute the advantages
183- for batch in dataloader :
183+ for batch in tqdm ( dataloader , desc = "Computing advantages" ) :
184184 for episode_index , current_idx , timestamp in zip (
185185 batch ["episode_index" ],
186186 batch ["current_idx" ],
@@ -192,13 +192,12 @@ def main(cfg: TrainPipelineConfig):
192192 )
193193 # look up the value at the step N_steps_look_ahead ahead of the current one; if it is not available, use 0 instead
194194 look_ahead_idx = current_idx + cfg .policy .reward_config .N_steps_look_ahead
195- vn = values .get ((episode_index , look_ahead_idx ), _default0 )["v0" ]
196- reward = values .get ((episode_index , current_idx ), _default0 )["reward" ]
197- v0 = values .get ((episode_index , current_idx ), _default0 )["v0" ]
195+ vn = values .get ((episode_index . item () , look_ahead_idx . item () ), _default0 )["v0" ]
196+ reward = values .get ((episode_index . item () , current_idx . item () ), _default0 )["reward" ]
197+ v0 = values .get ((episode_index . item () , current_idx . item () ), _default0 )["v0" ]
198198 advantage = ensure_primitive (reward + vn - v0 )
199- advantages .append (advantage )
200- ds_advantage [(episode_index , timestamp )] = advantage
201-
199+ advantages .append (advantage .item ())
200+ ds_advantage [(episode_index .item (), timestamp .item ())] = advantage .item ()
202201 # Convert tuple keys to strings for JSON serialization
203202 advantage_data_json = {f"{ ep_idx } ,{ ts } " : val for (ep_idx , ts ), val in ds_advantage .items ()}
204203
0 commit comments