@@ -431,9 +431,8 @@ def make_gym_env(env, observation_type):
431431
432432 # TRY NOT TO MODIFY: start the game
433433 obs , info = env .reset (seed = args .seed )
434+ episode = 1
434435 df_cell_obs = info ["df_cell" ] if "image" in args .observation_type else None
435- n = 1
436- done_one_time = False
437436 for global_step in range (args .total_timesteps ):
438437 # ALGO LOGIC: put action logic here
439438 if global_step <= args .learning_starts :
@@ -543,12 +542,11 @@ def make_gym_env(env, observation_type):
543542 "charts/episodic_return" , info ["episode" ]["r" ], global_step
544543 )
545544 writer .add_scalar (
546- "charts/episodic_length" , info ["episode" ]["l" ], global_step
545+ "charts/episodic_length" , info ["episode" ]["l" ],global_step
547546 )
548- if global_step >= 10000 * n and global_step < 12000 * n and not done_one_time :
549- done_one_time = True
550- n += 1
551- output_video = f"name_{ args .name } _seed_{ args .seed } _step_{ global_step } .mp4"
547+ episode += 1
548+ if episode % 64 == 0 :
549+ output_video = f"name_{ args .name } _seed_{ args .seed } _step_{ episode } .mp4"
552550 obs , info = env .reset (seed = args .seed )
553551 done = False
554552 step_episode = 0
@@ -559,13 +557,14 @@ def make_gym_env(env, observation_type):
559557 actions = actions .detach ().squeeze (0 ).cpu ().numpy ()
560558 obs , _ , terminated , truncated , info = env .step (actions )
561559 step_episode += 1
562- saving_img (image_folder = image_folder + f"/{ global_step } " ,info = info ,step_episode = step_episode ,x_max = x_max ,y_max = y_max ,x_min = x_min ,y_min = y_min )
560+ saving_img (image_folder = image_folder + f"/{ episode } " ,info = info ,step_episode = step_episode ,x_max = x_max ,y_max = y_max ,x_min = x_min ,y_min = y_min )
563561 if terminated or truncated :
564- png_to_video_imageio (image_folder + f"/{ global_step } /" + output_video , image_folder + f"/{ global_step } " , fps = 10 )
562+ png_to_video_imageio (image_folder + f"/{ episode } /" + output_video , image_folder + f"/{ episode } " , fps = 10 )
565563 if args .wandb_track :
566- wandb .log ({"test/simulation_video" : wandb .Video (image_folder + f"/{ global_step } /" + output_video , fps = 10 , format = "mp4" )})
564+ wandb .log ({"test/simulation_video" : wandb .Video (image_folder + f"/{ episode } /" + output_video , fps = 10 , format = "mp4" )})
567565 obs , _ = env .reset (seed = args .seed )
568- step_episode = 0
566+
567+
569568 else :
570569 obs , _ = env .reset (seed = args .seed )
571570 env .close ()
0 commit comments