# Copyright (c) 2024-2025, Muammer Bay (LycheeAI), Louis Le Lay
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Script to play a checkpoint of an RL agent from RSL-RL."""

"""Launch Isaac Sim Simulator first."""

# add argparse arguments
# NOTE: this is the play/evaluation script, so the description says "Play", not "Train".
parser = argparse.ArgumentParser(description="Play an RL agent with RSL-RL.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during play.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument(
    "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument(
    "--agent", type=str, default="rsl_rl_cfg_entry_point", help="Name of the RL agent configuration entry point."
)
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument(
    "--use_pretrained_checkpoint",
    action="store_true",
    help="Use the pre-trained checkpoint from Nucleus.",
)
parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.")
# append RSL-RL cli arguments
cli_args.add_rsl_rl_args(parser)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments; unknown args are forwarded to Hydra
args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video
if args_cli.video:
    args_cli.enable_cameras = True

# launch omniverse app (must happen before importing any isaaclab/omni modules)
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
5553
"""Rest everything follows."""

import gymnasium as gym
import os
import time
import torch

from rsl_rl.runners import DistillationRunner, OnPolicyRunner

from isaaclab.envs import (
    DirectMARLEnv,
    DirectMARLEnvCfg,
    DirectRLEnvCfg,
    ManagerBasedRLEnvCfg,
    multi_agent_to_single_agent,
)
from isaaclab.utils.assets import retrieve_file_path
from isaaclab.utils.dict import print_dict
from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint

from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx

import isaaclab_tasks  # noqa: F401
from isaaclab_tasks.utils import get_checkpoint_path
from isaaclab_tasks.utils.hydra import hydra_task_config

import SO_100.tasks  # noqa: F401
10782
@hydra_task_config(args_cli.task, args_cli.agent)
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
    """Play with an RSL-RL agent.

    Loads a trained checkpoint (pre-trained from Nucleus, an explicit ``--checkpoint``
    path, or the latest run under ``logs/rsl_rl``), exports the policy to JIT/ONNX,
    and runs inference in the environment until the app closes (or one video is
    recorded when ``--video`` is set).

    Args:
        env_cfg: Environment configuration resolved by Hydra.
        agent_cfg: RSL-RL runner configuration resolved by Hydra.
    """
    # grab task name for checkpoint path; strip a registry prefix like "isaaclab:" if present
    task_name = args_cli.task.split(":")[-1]
    # checkpoints are stored under the training task name, without the "-Play" suffix
    train_task_name = task_name.replace("-Play", "")

    # override configurations with non-hydra CLI arguments
    agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
    env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs

    # set the environment seed
    # note: certain randomizations occur in the environment initialization so we set the seed here
    env_cfg.seed = agent_cfg.seed
    env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device

    # specify directory for logging experiments
    log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
    log_root_path = os.path.abspath(log_root_path)
    print(f"[INFO] Loading experiment from directory: {log_root_path}")
    # resolve the checkpoint to load, in order of precedence:
    # 1) published pre-trained checkpoint, 2) explicit --checkpoint path, 3) latest local run
    if args_cli.use_pretrained_checkpoint:
        resume_path = get_published_pretrained_checkpoint("rsl_rl", train_task_name)
        if not resume_path:
            print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
            return
    elif args_cli.checkpoint:
        resume_path = retrieve_file_path(args_cli.checkpoint)
    else:
        resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)

    # log alongside the loaded checkpoint (used for the video folder)
    log_dir = os.path.dirname(resume_path)

    # create isaac environment
    env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

    # convert to single-agent instance if required by the RL algorithm
    if isinstance(env.unwrapped, DirectMARLEnv):
        env = multi_agent_to_single_agent(env)

    # wrap for video recording
    if args_cli.video:
        video_kwargs = {
            "video_folder": os.path.join(log_dir, "videos", "play"),
            # record a single clip starting at the first step
            "step_trigger": lambda step: step == 0,
            "video_length": args_cli.video_length,
            "disable_logger": True,
        }
        # NOTE(review): these three lines sit in omitted diff context — standard RecordVideo wiring.
        print("[INFO] Recording videos during play.")
        print_dict(video_kwargs, nesting=4)
        env = gym.wrappers.RecordVideo(env, **video_kwargs)

    # wrap around environment for rsl-rl
    env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)

    print(f"[INFO]: Loading model checkpoint from: {resume_path}")
    # load previously trained model; log_dir=None since we do not log during play
    if agent_cfg.class_name == "OnPolicyRunner":
        runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
    elif agent_cfg.class_name == "DistillationRunner":
        runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
    else:
        raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}")
    runner.load(resume_path)

    # obtain the trained policy for inference
    policy = runner.get_inference_policy(device=env.unwrapped.device)

    # extract the neural network module
    # we do this in a try-except to maintain backwards compatibility.
    try:
        # rsl-rl version 2.3 onwards
        policy_nn = runner.alg.policy
    except AttributeError:
        # rsl-rl version 2.2 and below
        policy_nn = runner.alg.actor_critic

    # extract the observation normalizer, if the policy has one
    if hasattr(policy_nn, "actor_obs_normalizer"):
        normalizer = policy_nn.actor_obs_normalizer
    elif hasattr(policy_nn, "student_obs_normalizer"):
        normalizer = policy_nn.student_obs_normalizer
    else:
        normalizer = None

    # export policy to onnx/jit next to the loaded checkpoint
    export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
    export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt")
    export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx")

    # environment step period, used to pace real-time playback
    dt = env.unwrapped.step_dt

    # reset environment
    obs = env.get_observations()
    timestep = 0
    # simulate environment
    while simulation_app.is_running():
        start_time = time.time()
        # run everything in inference mode
        with torch.inference_mode():
            # agent stepping
            actions = policy(obs)
            # env stepping
            obs, _, _, _ = env.step(actions)
        if args_cli.video:
            timestep += 1
            # exit the play loop after recording one video
            if timestep == args_cli.video_length:
                break

        # time delay for real-time evaluation
        sleep_time = dt - (time.time() - start_time)
        if args_cli.real_time and sleep_time > 0:
            time.sleep(sleep_time)

    # close the simulator
    env.close()
if __name__ == "__main__":
    # run the main function
    main()
    # close sim app
    simulation_app.close()