Skip to content

Commit 5b1ae36

Browse files
authored
Merge pull request #64 from MuammerBay/fix/train-play-errors
Updates train/play scripts for rsl_rl
2 parents 8eb4911 + b1ba024 commit 5b1ae36

File tree

5 files changed

+65
-29
lines changed

5 files changed

+65
-29
lines changed

scripts/rsl_rl/cli_args.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import TYPE_CHECKING
1616

1717
if TYPE_CHECKING:
18-
from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg
18+
from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg
1919

2020

2121
def add_rsl_rl_args(parser: argparse.ArgumentParser):
@@ -44,7 +44,7 @@ def add_rsl_rl_args(parser: argparse.ArgumentParser):
4444
)
4545

4646

47-
def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPolicyRunnerCfg:
47+
def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlBaseRunnerCfg:
4848
"""Parse configuration for RSL-RL agent based on inputs.
4949
5050
Args:
@@ -57,12 +57,12 @@ def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPol
5757
from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry
5858

5959
# load the default configuration
60-
rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")
60+
rslrl_cfg: RslRlBaseRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")
6161
rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli)
6262
return rslrl_cfg
6363

6464

65-
def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace):
65+
def update_rsl_rl_cfg(agent_cfg: RslRlBaseRunnerCfg, args_cli: argparse.Namespace):
6666
"""Update configuration for RSL-RL agent based on inputs.
6767
6868
Args:

scripts/rsl_rl/play.py

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Launch Isaac Sim Simulator first."""
1414

1515
import argparse
16+
import sys
1617

1718
from isaaclab.app import AppLauncher
1819

@@ -28,6 +29,10 @@
2829
)
2930
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
3031
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
32+
parser.add_argument(
33+
"--agent", type=str, default="rsl_rl_cfg_entry_point", help="Name of the RL agent configuration entry point."
34+
)
35+
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
3136
parser.add_argument(
3237
"--use_pretrained_checkpoint",
3338
action="store_true",
@@ -38,11 +43,15 @@
3843
cli_args.add_rsl_rl_args(parser)
3944
# append AppLauncher cli args
4045
AppLauncher.add_app_launcher_args(parser)
41-
args_cli = parser.parse_args()
46+
# parse the arguments
47+
args_cli, hydra_args = parser.parse_known_args()
4248
# always enable cameras to record video
4349
if args_cli.video:
4450
args_cli.enable_cameras = True
4551

52+
# clear out sys.argv for Hydra
53+
sys.argv = [sys.argv[0]] + hydra_args
54+
4655
# launch omniverse app
4756
app_launcher = AppLauncher(args_cli)
4857
simulation_app = app_launcher.app
@@ -56,34 +65,49 @@
5665
import isaaclab_tasks # noqa: F401
5766
import SO_100.tasks # noqa: F401
5867
import torch
59-
from isaaclab.envs import DirectMARLEnv, multi_agent_to_single_agent
68+
from isaaclab.envs import (
69+
DirectMARLEnv,
70+
DirectMARLEnvCfg,
71+
DirectRLEnvCfg,
72+
ManagerBasedRLEnvCfg,
73+
multi_agent_to_single_agent,
74+
)
6075
from isaaclab.utils.assets import retrieve_file_path
6176
from isaaclab.utils.dict import print_dict
6277
from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
6378
from isaaclab_rl.rsl_rl import (
64-
RslRlOnPolicyRunnerCfg,
79+
RslRlBaseRunnerCfg,
6580
RslRlVecEnvWrapper,
6681
export_policy_as_jit,
6782
export_policy_as_onnx,
6883
)
69-
from isaaclab_tasks.utils import get_checkpoint_path, parse_env_cfg
70-
from rsl_rl.runners import OnPolicyRunner
84+
from isaaclab_tasks.utils import get_checkpoint_path
85+
from isaaclab_tasks.utils.hydra import hydra_task_config
86+
from rsl_rl.runners import DistillationRunner, OnPolicyRunner
7187

7288

73-
def main():
89+
@hydra_task_config(args_cli.task, args_cli.agent)
90+
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
7491
"""Play with RSL-RL agent."""
75-
# parse configuration
76-
env_cfg = parse_env_cfg(
77-
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
78-
)
79-
agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli)
92+
# grab task name for checkpoint path
93+
task_name = args_cli.task.split(":")[-1]
94+
train_task_name = task_name.replace("-Play", "")
95+
96+
# override configurations with non-hydra CLI arguments
97+
agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
98+
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
99+
100+
# set the environment seed
101+
# note: certain randomizations occur in the environment initialization so we set the seed here
102+
env_cfg.seed = agent_cfg.seed
103+
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
80104

81105
# specify directory for logging experiments
82106
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
83107
log_root_path = os.path.abspath(log_root_path)
84108
print(f"[INFO] Loading experiment from directory: {log_root_path}")
85109
if args_cli.use_pretrained_checkpoint:
86-
resume_path = get_published_pretrained_checkpoint("rsl_rl", args_cli.task)
110+
resume_path = get_published_pretrained_checkpoint("rsl_rl", train_task_name)
87111
if not resume_path:
88112
print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
89113
return
@@ -118,32 +142,43 @@ def main():
118142

119143
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
120144
# load previously trained model
121-
ppo_runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
122-
ppo_runner.load(resume_path)
145+
if agent_cfg.class_name == "OnPolicyRunner":
146+
runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
147+
elif agent_cfg.class_name == "DistillationRunner":
148+
runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
149+
else:
150+
raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}")
151+
runner.load(resume_path)
123152

124153
# obtain the trained policy for inference
125-
policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)
154+
policy = runner.get_inference_policy(device=env.unwrapped.device)
126155

127156
# extract the neural network module
128157
# we do this in a try-except to maintain backwards compatibility.
129158
try:
130159
# version 2.3 onwards
131-
policy_nn = ppo_runner.alg.policy
160+
policy_nn = runner.alg.policy
132161
except AttributeError:
133162
# version 2.2 and below
134-
policy_nn = ppo_runner.alg.actor_critic
163+
policy_nn = runner.alg.actor_critic
164+
165+
# extract the normalizer
166+
if hasattr(policy_nn, "actor_obs_normalizer"):
167+
normalizer = policy_nn.actor_obs_normalizer
168+
elif hasattr(policy_nn, "student_obs_normalizer"):
169+
normalizer = policy_nn.student_obs_normalizer
170+
else:
171+
normalizer = None
135172

136173
# export policy to onnx/jit
137174
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
138-
export_policy_as_jit(policy_nn, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt")
139-
export_policy_as_onnx(
140-
policy_nn, normalizer=ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
141-
)
175+
export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt")
176+
export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx")
142177

143178
dt = env.unwrapped.step_dt
144179

145180
# reset environment
146-
obs, _ = env.get_observations()
181+
obs = env.get_observations()
147182
timestep = 0
148183
# simulate environment
149184
while simulation_app.is_running():

source/SO_100/SO_100/robots/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
# SPDX-License-Identifier: BSD-3-Clause
1010

1111
from .so_arm100 import *
12+
from .so_arm100_roscon import *

source/SO_100/SO_100/tasks/lift/agents/rsl_rl_ppo_cfg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class LiftCubePPORunnerCfg(RslRlOnPolicyRunnerCfg):
2121
num_steps_per_env = 24
2222
max_iterations = 1500
2323
save_interval = 50
24-
experiment_name = "so_arm100_lift"
24+
experiment_name = "lift"
2525
empirical_normalization = False
2626
policy = RslRlPpoActorCriticCfg(
2727
init_noise_std=1.0,

source/SO_100/SO_100/tasks/reach/agents/skrl_ppo_cfg.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ agent:
6666
time_limit_bootstrap: False
6767
# logging and checkpoint
6868
experiment:
69-
directory: "reach_so_arm100"
70-
experiment_name: "reach_so_arm100"
69+
directory: "reach"
70+
experiment_name: "reach"
7171
write_interval: auto
7272
checkpoint_interval: auto
7373

0 commit comments

Comments (0)