2020from wasabi import Printer
2121
2222import rl_zoo3 .import_envs # noqa: F401 pylint: disable=unused-import
23- from rl_zoo3 import ALGOS , create_test_env , get_saved_hyperparams
23+ from rl_zoo3 import ALGOS , get_saved_hyperparams
2424from rl_zoo3 .exp_manager import ExperimentManager
25- from rl_zoo3 .utils import StoreDict , get_model_path
25+ from rl_zoo3 .utils import StoreDict , create_test_env , get_model_path
2626
2727msg = Printer ()
2828
@@ -277,12 +277,12 @@ def package_to_hub(
277277
278278if __name__ == "__main__" :
279279 parser = argparse .ArgumentParser ()
280- parser .add_argument ("--env" , help = "environment ID" , type = EnvironmentName , required = True )
280+ parser .add_argument ("--env" , help = "Environment ID" , type = EnvironmentName , required = True )
281281 parser .add_argument ("-f" , "--folder" , help = "Log folder" , type = str , required = True )
282282 parser .add_argument ("--algo" , help = "RL Algorithm" , type = str , required = True , choices = list (ALGOS .keys ()))
283- parser .add_argument ("-n" , "--n-timesteps" , help = "number of timesteps" , default = 1000 , type = int )
283+ parser .add_argument ("-n" , "--n-timesteps" , help = "Number of timesteps for the video recording " , default = 1000 , type = int )
284284 parser .add_argument ("--num-threads" , help = "Number of threads for PyTorch (-1 to use default)" , default = - 1 , type = int )
285- parser .add_argument ("--n-envs" , help = "number of environments" , default = 1 , type = int )
285+ parser .add_argument ("--n-envs" , help = "Number of environments" , default = 1 , type = int )
286286 parser .add_argument ("--exp-id" , help = "Experiment ID (default: 0: latest, -1: no exp folder)" , default = 0 , type = int )
287287 parser .add_argument ("--verbose" , help = "Verbose mode (0: no output, 1: INFO)" , default = 1 , type = int )
288288 parser .add_argument (
@@ -357,6 +357,12 @@ def package_to_hub(
357357 loaded_args = yaml .load (f , Loader = yaml .UnsafeLoader ) # pytype: disable=module-attr
358358 if loaded_args ["env_kwargs" ] is not None :
359359 env_kwargs = loaded_args ["env_kwargs" ]
360+
361+ # render and record video by default
362+ should_render = not args .no_render
363+ if should_render :
364+ env_kwargs .update (render_mode = "rgb_array" )
365+
360366 # overwrite with command line arguments
361367 if args .env_kwargs is not None :
362368 env_kwargs .update (args .env_kwargs )
@@ -367,7 +373,7 @@ def package_to_hub(
367373 stats_path = maybe_stats_path ,
368374 seed = args .seed ,
369375 log_dir = None ,
370- should_render = not args . no_render ,
376+ should_render = should_render ,
371377 hyperparams = deepcopy (hyperparams ),
372378 env_kwargs = env_kwargs ,
373379 )
@@ -377,6 +383,12 @@ def package_to_hub(
377383 # Dummy buffer size as we don't need memory to enjoy the trained agent
378384 kwargs .update (dict (buffer_size = 1 ))
379385
386+ # Hack due to a breaking change in v1.6:
387+ # handle_timeout_termination cannot be used at the same
388+ # time as optimize_memory_usage
389+ if "optimize_memory_usage" in hyperparams :
390+ kwargs .update (optimize_memory_usage = False )
391+
380392 # Note: we assume that we push models using the same machine (same python version)
381393 # that trained them, if not, we would need to pass custom object as in enjoy.py
382394 custom_objects : Dict [str , Any ] = {}
@@ -411,6 +423,6 @@ def package_to_hub(
411423 n_eval_episodes = 10 ,
412424 token = None ,
413425 local_repo_path = "hub" ,
414- video_length = 1000 ,
426+ video_length = args . n_timesteps ,
415427 generate_video = not args .no_render ,
416428 )
0 commit comments