33import pathlib
44from typing import Callable
55
6+ from stable_baselines3 import PPO
67from stable_baselines3 .common .callbacks import CheckpointCallback
8+ from stable_baselines3 .common .vec_env .vec_monitor import VecMonitor
9+
710from godot_rl .core .utils import can_import
8- from godot_rl .wrappers .stable_baselines_wrapper import StableBaselinesGodotEnv
911from godot_rl .wrappers .onnx .stable_baselines_export import export_ppo_model_as_onnx
10- from stable_baselines3 import PPO
11- from stable_baselines3 .common .vec_env .vec_monitor import VecMonitor
12+ from godot_rl .wrappers .stable_baselines_wrapper import StableBaselinesGodotEnv
1213
1314# To download the env source and binary:
1415# 1. gdrl.env_from_hub -r edbeeching/godot_rl_BallChase
2829 default = "logs/sb3" ,
2930 type = str ,
3031 help = "The name of the experiment directory, in which the tensorboard logs and checkpoints (if enabled) are "
31- "getting stored."
32+ "getting stored." ,
3233)
# --- Experiment identity, seeding, and model save/resume options ---

# Label for this run; also reused as the checkpoint file-name prefix further down the script.
parser.add_argument(
    "--experiment_name",
    default="experiment",
    type=str,
    help="The name of the experiment, which will be displayed in tensorboard and "
    "for checkpoint directory and name (if enabled).",
)
# Forwarded to the Godot environment wrapper as its `seed` argument.
parser.add_argument("--seed", type=int, default=0, help="seed of the experiment")
# When left as None a fresh PPO model is created; --inference requires this to be set.
parser.add_argument(
    "--resume_model_path",
    default=None,
    type=str,
    # The help text refers to the checkpoint flag by its real name (--save_checkpoint_frequency)
    # so users can find it in --help output.
    help="The path to a model file previously saved using --save_model_path or a checkpoint saved using "
    "--save_checkpoint_frequency. Use this to resume training or infer from a saved model.",
)
parser.add_argument(
    "--save_model_path",
    default=None,
    type=str,
    help="The path to use for saving the trained sb3 model after training is complete. Saved model can be used later "
    "to resume training. Extension will be set to .zip",
)
# Expressed in environment steps; None disables checkpointing entirely.
parser.add_argument(
    "--save_checkpoint_frequency",
    default=None,
    type=int,
    help=(
        "If set, will save checkpoints every 'frequency' environment steps. "
        "Requires a unique --experiment_name or --experiment_dir for each run. "
        "Does not need --save_model_path to be set. "
    ),
)
6866parser .add_argument (
6967 "--onnx_export_path" ,
7674 default = 1_000_000 ,
7775 type = int ,
7876 help = "The number of environment steps to train for, default is 1_000_000. If resuming from a saved model, "
79- "it will continue training for this amount of steps from the saved state without counting previously trained "
80- "steps" ,
77+ "it will continue training for this amount of steps from the saved state without counting previously trained "
78+ "steps" ,
8179)
# --- Run-mode and environment options ---

# Boolean switch: run the loaded model instead of training it.
parser.add_argument(
    "--inference",
    default=False,
    action="store_true",
    help="Instead of training, it will run inference on a loaded model for --timesteps steps. "
    "Requires --resume_model_path to be set.",
)
# Boolean switch: selects a linearly decaying learning rate instead of a constant one.
parser.add_argument(
    "--linear_lr_schedule",
    default=False,
    action="store_true",
    # Adjacent string literals are concatenated verbatim, so the separating space
    # after "--timesteps" has to be part of the literal itself.
    help="Use a linear LR schedule for training. If set, learning rate will decrease until it reaches 0 at "
    "--timesteps "
    "value. Note: On resuming training, the schedule will reset. If disabled, constant LR will be used.",
)
# Boolean switch: forwarded to the environment wrapper as `show_window`.
parser.add_argument(
    "--viz",
    action="store_true",
    help="If set, the simulation will be displayed in a window during training. Otherwise "
    "training will run without rendering the simulation. This setting does not apply to in-editor training.",
    default=False,
)
parser.add_argument("--speedup", default=1, type=int, help="Whether to speed up the physics in the env")
parser.add_argument(
    "--n_parallel",
    default=1,
    type=int,
    help="How many instances of the environment executable to launch - requires --env_path to be set if > 1.",
)

# parse_known_args() returns (namespace, leftover_argv): unrecognised command-line
# tokens are collected in `extras` instead of aborting the program.
args, extras = parser.parse_known_args()
108110
109111
@@ -136,19 +138,22 @@ def close_env():
136138
# Prevent overwriting existing checkpoints when starting a new experiment if checkpoint saving is enabled.
# The guard only applies when checkpointing was requested and the target directory is already on disk.
if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint):
    # Each fragment ends with its own separator (space) because adjacent string
    # literals are joined without any implicit whitespace.
    raise RuntimeError(
        abs_path_checkpoint + " folder already exists. "
        "Use a different --experiment_dir, or --experiment_name, "
        "or if previous checkpoints are not needed anymore, "
        "remove the folder containing the checkpoints. "
    )
143147
# Inference needs an existing model to load, so reject the combination early.
if args.inference and args.resume_model_path is None:
    # ArgumentParser.error() prints the usage string plus this message to stderr and
    # terminates the process with exit status 2; it never returns, so there is
    # nothing to `raise`.
    parser.error("Using --inference requires --resume_model_path to be set.")

# Purely informational: warn that --viz is ignored when no exported executable is given.
if args.env_path is None and args.viz:
    print("Info: Using --viz without --env_path set has no effect, in-editor training will always render.")

# Build the Godot environment from the parsed command-line options.
env = StableBaselinesGodotEnv(
    env_path=args.env_path,
    show_window=args.viz,
    seed=args.seed,
    n_parallel=args.n_parallel,
    speedup=args.speedup,
)
# Wrap the environment in Stable-Baselines3's VecMonitor wrapper.
env = VecMonitor(env)
153158
154159
@@ -177,13 +182,15 @@ def func(progress_remaining: float) -> float:
177182
178183if args .resume_model_path is None :
179184 learning_rate = 0.0003 if not args .linear_lr_schedule else linear_schedule (0.0003 )
180- model : PPO = PPO ("MultiInputPolicy" ,
181- env ,
182- ent_coef = 0.0001 ,
183- verbose = 2 ,
184- n_steps = 32 ,
185- tensorboard_log = args .experiment_dir ,
186- learning_rate = learning_rate )
185+ model : PPO = PPO (
186+ "MultiInputPolicy" ,
187+ env ,
188+ ent_coef = 0.0001 ,
189+ verbose = 2 ,
190+ n_steps = 32 ,
191+ tensorboard_log = args .experiment_dir ,
192+ learning_rate = learning_rate ,
193+ )
187194else :
188195 path_zip = pathlib .Path (args .resume_model_path )
189196 print ("Loading model: " + os .path .abspath (path_zip ))
@@ -201,13 +208,16 @@ def func(progress_remaining: float) -> float:
201208 checkpoint_callback = CheckpointCallback (
202209 save_freq = (args .save_checkpoint_frequency // env .num_envs ),
203210 save_path = path_checkpoint ,
204- name_prefix = args .experiment_name
211+ name_prefix = args .experiment_name ,
205212 )
206- learn_arguments [' callback' ] = checkpoint_callback
213+ learn_arguments [" callback" ] = checkpoint_callback
207214 try :
208215 model .learn (** learn_arguments )
209216 except KeyboardInterrupt :
210- print ("Training interrupted by user. Will save if --save_model_path was used and/or export if --onnx_export_path was used." )
217+ print (
218+ """Training interrupted by user. Will save if --save_model_path was
219+ used and/or export if --onnx_export_path was used."""
220+ )
211221
212222close_env ()
213223handle_onnx_export ()
0 commit comments