Skip to content

Commit 5cea2a2

Browse files
committed
corrects play script
1 parent 9ce2a4d commit 5cea2a2

File tree

3 files changed

+100
-107
lines changed

3 files changed

+100
-107
lines changed

scripts/rsl_rl/play.py

Lines changed: 97 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
1-
# Copyright (c) 2024-2025, Muammer Bay (LycheeAI), Louis Le Lay
2-
# All rights reserved.
3-
#
4-
# SPDX-License-Identifier: BSD-3-Clause
5-
#
6-
# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
1+
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
72
# All rights reserved.
83
#
94
# SPDX-License-Identifier: BSD-3-Clause
105

11-
"""Script to train RL agent with RSL-RL."""
6+
"""Script to play a checkpoint of an RL agent from RSL-RL."""
127

138
"""Launch Isaac Sim Simulator first."""
149

@@ -24,24 +19,27 @@
2419
parser = argparse.ArgumentParser(description="Play an RL agent with RSL-RL.")
2520
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
2621
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
27-
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
22+
parser.add_argument(
23+
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
24+
)
2825
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
2926
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
3027
parser.add_argument(
3128
"--agent", type=str, default="rsl_rl_cfg_entry_point", help="Name of the RL agent configuration entry point."
3229
)
3330
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
34-
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
3531
parser.add_argument(
36-
"--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes."
32+
"--use_pretrained_checkpoint",
33+
action="store_true",
34+
help="Use the pre-trained checkpoint from Nucleus.",
3735
)
38-
parser.add_argument("--export_io_descriptors", action="store_true", default=False, help="Export IO descriptors.")
36+
parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.")
3937
# append RSL-RL cli arguments
4038
cli_args.add_rsl_rl_args(parser)
4139
# append AppLauncher cli args
4240
AppLauncher.add_app_launcher_args(parser)
41+
# parse the arguments
4342
args_cli, hydra_args = parser.parse_known_args()
44-
4543
# always enable cameras to record video
4644
if args_cli.video:
4745
args_cli.enable_cameras = True
@@ -53,103 +51,66 @@
5351
app_launcher = AppLauncher(args_cli)
5452
simulation_app = app_launcher.app
5553

56-
"""Check for minimum supported RSL-RL version."""
57-
58-
import importlib.metadata as metadata
59-
import platform
60-
61-
from packaging import version
62-
63-
# check minimum supported rsl-rl version
64-
RSL_RL_VERSION = "3.0.1"
65-
installed_version = metadata.version("rsl-rl-lib")
66-
if version.parse(installed_version) < version.parse(RSL_RL_VERSION):
67-
if platform.system() == "Windows":
68-
cmd = [r".\isaaclab.bat", "-p", "-m", "pip", "install", f"rsl-rl-lib=={RSL_RL_VERSION}"]
69-
else:
70-
cmd = ["./isaaclab.sh", "-p", "-m", "pip", "install", f"rsl-rl-lib=={RSL_RL_VERSION}"]
71-
print(
72-
f"Please install the correct version of RSL-RL.\nExisting version is: '{installed_version}'"
73-
f" and required version is: '{RSL_RL_VERSION}'.\nTo install the correct version, run:"
74-
f"\n\n\t{' '.join(cmd)}\n"
75-
)
76-
exit(1)
77-
7854
"""Rest everything follows."""
7955

80-
import os
81-
from datetime import datetime
82-
8356
import gymnasium as gym
84-
import isaaclab_tasks # noqa: F401
85-
import omni
86-
import SO_100.tasks # noqa: F401
57+
import os
58+
import time
8759
import torch
60+
61+
from rsl_rl.runners import DistillationRunner, OnPolicyRunner
62+
8863
from isaaclab.envs import (
8964
DirectMARLEnv,
9065
DirectMARLEnvCfg,
9166
DirectRLEnvCfg,
9267
ManagerBasedRLEnvCfg,
9368
multi_agent_to_single_agent,
9469
)
70+
from isaaclab.utils.assets import retrieve_file_path
9571
from isaaclab.utils.dict import print_dict
96-
from isaaclab.utils.io import dump_pickle, dump_yaml
97-
from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper
72+
from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
73+
74+
from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx
75+
76+
import isaaclab_tasks # noqa: F401
9877
from isaaclab_tasks.utils import get_checkpoint_path
9978
from isaaclab_tasks.utils.hydra import hydra_task_config
100-
from rsl_rl.runners import DistillationRunner, OnPolicyRunner
10179

102-
torch.backends.cuda.matmul.allow_tf32 = True
103-
torch.backends.cudnn.allow_tf32 = True
104-
torch.backends.cudnn.deterministic = False
105-
torch.backends.cudnn.benchmark = False
80+
import SO_100.tasks # noqa: F401
10681

10782

10883
@hydra_task_config(args_cli.task, args_cli.agent)
10984
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
110-
"""Train with RSL-RL agent."""
85+
"""Play with RSL-RL agent."""
86+
# grab task name for checkpoint path
87+
task_name = args_cli.task.split(":")[-1]
88+
train_task_name = task_name.replace("-Play", "")
89+
11190
# override configurations with non-hydra CLI arguments
112-
agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
91+
agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
11392
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
114-
agent_cfg.max_iterations = (
115-
args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg.max_iterations
116-
)
11793

11894
# set the environment seed
11995
# note: certain randomizations occur in the environment initialization so we set the seed here
12096
env_cfg.seed = agent_cfg.seed
12197
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
12298

123-
# multi-gpu training configuration
124-
if args_cli.distributed:
125-
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
126-
agent_cfg.device = f"cuda:{app_launcher.local_rank}"
127-
128-
# set seed to have diversity in different threads
129-
seed = agent_cfg.seed + app_launcher.local_rank
130-
env_cfg.seed = seed
131-
agent_cfg.seed = seed
132-
13399
# specify directory for logging experiments
134100
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
135101
log_root_path = os.path.abspath(log_root_path)
136-
print(f"[INFO] Logging experiment in directory: {log_root_path}")
137-
# specify directory for logging runs: {time-stamp}_{run_name}
138-
log_dir = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
139-
# The Ray Tune workflow extracts experiment name using the logging line below, hence, do not change it (see PR #2346, comment-2819298849)
140-
print(f"Exact experiment name requested from command line: {log_dir}")
141-
if agent_cfg.run_name:
142-
log_dir += f"_{agent_cfg.run_name}"
143-
log_dir = os.path.join(log_root_path, log_dir)
144-
145-
# set the IO descriptors output directory if requested
146-
if isinstance(env_cfg, ManagerBasedRLEnvCfg):
147-
env_cfg.export_io_descriptors = args_cli.export_io_descriptors
148-
env_cfg.io_descriptors_output_dir = log_dir
102+
print(f"[INFO] Loading experiment from directory: {log_root_path}")
103+
if args_cli.use_pretrained_checkpoint:
104+
resume_path = get_published_pretrained_checkpoint("rsl_rl", train_task_name)
105+
if not resume_path:
106+
print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
107+
return
108+
elif args_cli.checkpoint:
109+
resume_path = retrieve_file_path(args_cli.checkpoint)
149110
else:
150-
omni.log.warn(
151-
"IO descriptors are only supported for manager based RL environments. No IO descriptors will be exported."
152-
)
111+
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
112+
113+
log_dir = os.path.dirname(resume_path)
153114

154115
# create isaac environment
155116
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
@@ -158,15 +119,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
158119
if isinstance(env.unwrapped, DirectMARLEnv):
159120
env = multi_agent_to_single_agent(env)
160121

161-
# save resume path before creating a new log_dir
162-
if agent_cfg.resume or agent_cfg.algorithm.class_name == "Distillation":
163-
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
164-
165122
# wrap for video recording
166123
if args_cli.video:
167124
video_kwargs = {
168-
"video_folder": os.path.join(log_dir, "videos", "train"),
169-
"step_trigger": lambda step: step % args_cli.video_interval == 0,
125+
"video_folder": os.path.join(log_dir, "videos", "play"),
126+
"step_trigger": lambda step: step == 0,
170127
"video_length": args_cli.video_length,
171128
"disable_logger": True,
172129
}
@@ -177,29 +134,65 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
177134
# wrap around environment for rsl-rl
178135
env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)
179136

180-
# create runner from rsl-rl
137+
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
138+
# load previously trained model
181139
if agent_cfg.class_name == "OnPolicyRunner":
182-
runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device)
140+
runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
183141
elif agent_cfg.class_name == "DistillationRunner":
184-
runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device)
142+
runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
185143
else:
186144
raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}")
187-
# write git state to logs
188-
runner.add_git_repo_to_log(__file__)
189-
# load the checkpoint
190-
if agent_cfg.resume or agent_cfg.algorithm.class_name == "Distillation":
191-
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
192-
# load previously trained model
193-
runner.load(resume_path)
194-
195-
# dump the configuration into log-directory
196-
dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg)
197-
dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg)
198-
dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg)
199-
dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg)
200-
201-
# run training
202-
runner.learn(num_learning_iterations=agent_cfg.max_iterations, init_at_random_ep_len=True)
145+
runner.load(resume_path)
146+
147+
# obtain the trained policy for inference
148+
policy = runner.get_inference_policy(device=env.unwrapped.device)
149+
150+
# extract the neural network module
151+
# we do this in a try-except to maintain backwards compatibility.
152+
try:
153+
# version 2.3 onwards
154+
policy_nn = runner.alg.policy
155+
except AttributeError:
156+
# version 2.2 and below
157+
policy_nn = runner.alg.actor_critic
158+
159+
# extract the normalizer
160+
if hasattr(policy_nn, "actor_obs_normalizer"):
161+
normalizer = policy_nn.actor_obs_normalizer
162+
elif hasattr(policy_nn, "student_obs_normalizer"):
163+
normalizer = policy_nn.student_obs_normalizer
164+
else:
165+
normalizer = None
166+
167+
# export policy to onnx/jit
168+
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
169+
export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt")
170+
export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx")
171+
172+
dt = env.unwrapped.step_dt
173+
174+
# reset environment
175+
obs = env.get_observations()
176+
timestep = 0
177+
# simulate environment
178+
while simulation_app.is_running():
179+
start_time = time.time()
180+
# run everything in inference mode
181+
with torch.inference_mode():
182+
# agent stepping
183+
actions = policy(obs)
184+
# env stepping
185+
obs, _, _, _ = env.step(actions)
186+
if args_cli.video:
187+
timestep += 1
188+
# Exit the play loop after recording one video
189+
if timestep == args_cli.video_length:
190+
break
191+
192+
# time delay for real-time evaluation
193+
sleep_time = dt - (time.time() - start_time)
194+
if args_cli.real_time and sleep_time > 0:
195+
time.sleep(sleep_time)
203196

204197
# close the simulator
205198
env.close()
@@ -209,4 +202,4 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
209202
# run the main function
210203
main()
211204
# close sim app
212-
simulation_app.close()
205+
simulation_app.close()

source/SO_100/SO_100/tasks/lift/agents/rsl_rl_ppo_cfg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class LiftCubePPORunnerCfg(RslRlOnPolicyRunnerCfg):
2121
num_steps_per_env = 24
2222
max_iterations = 1500
2323
save_interval = 50
24-
experiment_name = "so_arm100_lift"
24+
experiment_name = "lift"
2525
empirical_normalization = False
2626
policy = RslRlPpoActorCriticCfg(
2727
init_noise_std=1.0,

source/SO_100/SO_100/tasks/reach/agents/skrl_ppo_cfg.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ agent:
6666
time_limit_bootstrap: False
6767
# logging and checkpoint
6868
experiment:
69-
directory: "reach_so_arm100"
70-
experiment_name: "reach_so_arm100"
69+
directory: "reach"
70+
experiment_name: "reach"
7171
write_interval: auto
7272
checkpoint_interval: auto
7373

0 commit comments

Comments (0)