Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions scripts/rsl_rl/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg
from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg


def add_rsl_rl_args(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -44,7 +44,7 @@ def add_rsl_rl_args(parser: argparse.ArgumentParser):
)


def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPolicyRunnerCfg:
def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlBaseRunnerCfg:
"""Parse configuration for RSL-RL agent based on inputs.

Args:
Expand All @@ -57,12 +57,12 @@ def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPol
from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry

# load the default configuration
rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")
rslrl_cfg: RslRlBaseRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")
rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli)
return rslrl_cfg


def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace):
def update_rsl_rl_cfg(agent_cfg: RslRlBaseRunnerCfg, args_cli: argparse.Namespace):
"""Update configuration for RSL-RL agent based on inputs.

Args:
Expand Down
79 changes: 57 additions & 22 deletions scripts/rsl_rl/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""Launch Isaac Sim Simulator first."""

import argparse
import sys

from isaaclab.app import AppLauncher

Expand All @@ -28,6 +29,10 @@
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument(
"--agent", type=str, default="rsl_rl_cfg_entry_point", help="Name of the RL agent configuration entry point."
)
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument(
"--use_pretrained_checkpoint",
action="store_true",
Expand All @@ -38,11 +43,15 @@
cli_args.add_rsl_rl_args(parser)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()
# parse the arguments
args_cli, hydra_args = parser.parse_known_args()
# enable cameras whenever video recording is requested
if args_cli.video:
args_cli.enable_cameras = True

# keep only the program name and the arguments argparse did not recognize in sys.argv, so Hydra parses just those
sys.argv = [sys.argv[0]] + hydra_args

# launch omniverse app
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
Expand All @@ -56,34 +65,49 @@
import isaaclab_tasks # noqa: F401
import SO_100.tasks # noqa: F401
import torch
from isaaclab.envs import DirectMARLEnv, multi_agent_to_single_agent
from isaaclab.envs import (
DirectMARLEnv,
DirectMARLEnvCfg,
DirectRLEnvCfg,
ManagerBasedRLEnvCfg,
multi_agent_to_single_agent,
)
from isaaclab.utils.assets import retrieve_file_path
from isaaclab.utils.dict import print_dict
from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
from isaaclab_rl.rsl_rl import (
RslRlOnPolicyRunnerCfg,
RslRlBaseRunnerCfg,
RslRlVecEnvWrapper,
export_policy_as_jit,
export_policy_as_onnx,
)
from isaaclab_tasks.utils import get_checkpoint_path, parse_env_cfg
from rsl_rl.runners import OnPolicyRunner
from isaaclab_tasks.utils import get_checkpoint_path
from isaaclab_tasks.utils.hydra import hydra_task_config
from rsl_rl.runners import DistillationRunner, OnPolicyRunner


def main():
@hydra_task_config(args_cli.task, args_cli.agent)
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
"""Play with RSL-RL agent."""
# parse configuration
env_cfg = parse_env_cfg(
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
)
agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli)
# take the task name after the last ":" and drop "-Play" to get the training task name used for the pretrained checkpoint lookup
task_name = args_cli.task.split(":")[-1]
train_task_name = task_name.replace("-Play", "")

# override configurations with non-hydra CLI arguments
agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs

# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg.seed
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device

# specify directory for logging experiments
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}")
if args_cli.use_pretrained_checkpoint:
resume_path = get_published_pretrained_checkpoint("rsl_rl", args_cli.task)
resume_path = get_published_pretrained_checkpoint("rsl_rl", train_task_name)
if not resume_path:
print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
return
Expand Down Expand Up @@ -118,32 +142,43 @@ def main():

print(f"[INFO]: Loading model checkpoint from: {resume_path}")
# load previously trained model
ppo_runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
ppo_runner.load(resume_path)
if agent_cfg.class_name == "OnPolicyRunner":
runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
elif agent_cfg.class_name == "DistillationRunner":
runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
else:
raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}")
runner.load(resume_path)

# obtain the trained policy for inference
policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)
policy = runner.get_inference_policy(device=env.unwrapped.device)

# extract the neural network module
# we do this in a try-except to maintain backwards compatibility.
try:
# version 2.3 onwards
policy_nn = ppo_runner.alg.policy
policy_nn = runner.alg.policy
except AttributeError:
# version 2.2 and below
policy_nn = ppo_runner.alg.actor_critic
policy_nn = runner.alg.actor_critic

# extract the normalizer
if hasattr(policy_nn, "actor_obs_normalizer"):
normalizer = policy_nn.actor_obs_normalizer
elif hasattr(policy_nn, "student_obs_normalizer"):
normalizer = policy_nn.student_obs_normalizer
else:
normalizer = None

# export policy to onnx/jit
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
export_policy_as_jit(policy_nn, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt")
export_policy_as_onnx(
policy_nn, normalizer=ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
)
export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt")
export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx")

dt = env.unwrapped.step_dt

# reset environment
obs, _ = env.get_observations()
obs = env.get_observations()
timestep = 0
# simulate environment
while simulation_app.is_running():
Expand Down
1 change: 1 addition & 0 deletions source/SO_100/SO_100/robots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
# SPDX-License-Identifier: BSD-3-Clause

from .so_arm100 import *
from .so_arm100_roscon import *
2 changes: 1 addition & 1 deletion source/SO_100/SO_100/tasks/lift/agents/rsl_rl_ppo_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class LiftCubePPORunnerCfg(RslRlOnPolicyRunnerCfg):
num_steps_per_env = 24
max_iterations = 1500
save_interval = 50
experiment_name = "so_arm100_lift"
experiment_name = "lift"
empirical_normalization = False
policy = RslRlPpoActorCriticCfg(
init_noise_std=1.0,
Expand Down
4 changes: 2 additions & 2 deletions source/SO_100/SO_100/tasks/reach/agents/skrl_ppo_cfg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ agent:
time_limit_bootstrap: False
# logging and checkpoint
experiment:
directory: "reach_so_arm100"
experiment_name: "reach_so_arm100"
directory: "reach"
experiment_name: "reach"
write_interval: auto
checkpoint_interval: auto

Expand Down