diff --git a/scripts/rsl_rl/cli_args.py b/scripts/rsl_rl/cli_args.py index 4bc4f87..fde3046 100644 --- a/scripts/rsl_rl/cli_args.py +++ b/scripts/rsl_rl/cli_args.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg + from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg def add_rsl_rl_args(parser: argparse.ArgumentParser): @@ -44,7 +44,7 @@ def add_rsl_rl_args(parser: argparse.ArgumentParser): ) -def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPolicyRunnerCfg: +def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlBaseRunnerCfg: """Parse configuration for RSL-RL agent based on inputs. Args: @@ -57,12 +57,12 @@ def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPol from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry # load the default configuration - rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point") + rslrl_cfg: RslRlBaseRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point") rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli) return rslrl_cfg -def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace): +def update_rsl_rl_cfg(agent_cfg: RslRlBaseRunnerCfg, args_cli: argparse.Namespace): """Update configuration for RSL-RL agent based on inputs. Args: diff --git a/scripts/rsl_rl/play.py b/scripts/rsl_rl/play.py index dba2368..a9a98e5 100644 --- a/scripts/rsl_rl/play.py +++ b/scripts/rsl_rl/play.py @@ -13,6 +13,7 @@ """Launch Isaac Sim Simulator first.""" import argparse +import sys from isaaclab.app import AppLauncher @@ -28,6 +29,10 @@ ) parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--task", type=str, default=None, help="Name of the task.") +parser.add_argument( + "--agent", type=str, default="rsl_rl_cfg_entry_point", help="Name of the RL agent configuration entry point." +) +parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") parser.add_argument( "--use_pretrained_checkpoint", action="store_true", @@ -38,11 +43,15 @@ cli_args.add_rsl_rl_args(parser) # append AppLauncher cli args AppLauncher.add_app_launcher_args(parser) -args_cli = parser.parse_args() +# parse the arguments +args_cli, hydra_args = parser.parse_known_args() # always enable cameras to record video if args_cli.video: args_cli.enable_cameras = True +# clear out sys.argv for Hydra +sys.argv = [sys.argv[0]] + hydra_args + # launch omniverse app app_launcher = AppLauncher(args_cli) simulation_app = app_launcher.app @@ -56,34 +65,49 @@ import isaaclab_tasks # noqa: F401 import SO_100.tasks # noqa: F401 import torch -from isaaclab.envs import DirectMARLEnv, multi_agent_to_single_agent +from isaaclab.envs import ( + DirectMARLEnv, + DirectMARLEnvCfg, + DirectRLEnvCfg, + ManagerBasedRLEnvCfg, + multi_agent_to_single_agent, +) from isaaclab.utils.assets import retrieve_file_path from isaaclab.utils.dict import print_dict from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint from isaaclab_rl.rsl_rl import ( - RslRlOnPolicyRunnerCfg, + RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx, ) -from isaaclab_tasks.utils import get_checkpoint_path, parse_env_cfg -from rsl_rl.runners import OnPolicyRunner +from isaaclab_tasks.utils import get_checkpoint_path +from isaaclab_tasks.utils.hydra import hydra_task_config +from rsl_rl.runners import DistillationRunner, OnPolicyRunner -def main(): +@hydra_task_config(args_cli.task, args_cli.agent) +def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg): """Play with RSL-RL agent.""" - # parse configuration - env_cfg = parse_env_cfg( - args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric - ) - agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli) + # grab task name for checkpoint path + task_name = args_cli.task.split(":")[-1] + train_task_name = task_name.replace("-Play", "") + + # override configurations with non-hydra CLI arguments + agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli) + env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs + + # set the environment seed + # note: certain randomizations occur in the environment initialization so we set the seed here + env_cfg.seed = agent_cfg.seed + env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device # specify directory for logging experiments log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) log_root_path = os.path.abspath(log_root_path) print(f"[INFO] Loading experiment from directory: {log_root_path}") if args_cli.use_pretrained_checkpoint: - resume_path = get_published_pretrained_checkpoint("rsl_rl", args_cli.task) + resume_path = get_published_pretrained_checkpoint("rsl_rl", train_task_name) if not resume_path: print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.") return @@ -118,32 +142,43 @@ def main(): print(f"[INFO]: Loading model checkpoint from: {resume_path}") # load previously trained model - ppo_runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device) - ppo_runner.load(resume_path) + if agent_cfg.class_name == "OnPolicyRunner": + runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device) + elif agent_cfg.class_name == "DistillationRunner": + runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device) + else: + raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}") + runner.load(resume_path) # obtain the trained policy for inference - policy = ppo_runner.get_inference_policy(device=env.unwrapped.device) + policy = runner.get_inference_policy(device=env.unwrapped.device) # extract the neural network module # we do this in a try-except to maintain backwards compatibility. try: # version 2.3 onwards - policy_nn = ppo_runner.alg.policy + policy_nn = runner.alg.policy except AttributeError: # version 2.2 and below - policy_nn = ppo_runner.alg.actor_critic + policy_nn = runner.alg.actor_critic + + # extract the normalizer + if hasattr(policy_nn, "actor_obs_normalizer"): + normalizer = policy_nn.actor_obs_normalizer + elif hasattr(policy_nn, "student_obs_normalizer"): + normalizer = policy_nn.student_obs_normalizer + else: + normalizer = None # export policy to onnx/jit export_model_dir = os.path.join(os.path.dirname(resume_path), "exported") - export_policy_as_jit(policy_nn, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt") - export_policy_as_onnx( - policy_nn, normalizer=ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx" - ) + export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt") + export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx") dt = env.unwrapped.step_dt # reset environment - obs, _ = env.get_observations() + obs = env.get_observations() timestep = 0 # simulate environment while simulation_app.is_running(): diff --git a/source/SO_100/SO_100/robots/__init__.py b/source/SO_100/SO_100/robots/__init__.py index 9e63b1b..aa0a926 100644 --- a/source/SO_100/SO_100/robots/__init__.py +++ b/source/SO_100/SO_100/robots/__init__.py @@ -9,3 +9,4 @@ # SPDX-License-Identifier: BSD-3-Clause from .so_arm100 import * +from .so_arm100_roscon import * diff --git a/source/SO_100/SO_100/tasks/lift/agents/rsl_rl_ppo_cfg.py b/source/SO_100/SO_100/tasks/lift/agents/rsl_rl_ppo_cfg.py index 1ae2c8e..27fa0d9 100644 --- a/source/SO_100/SO_100/tasks/lift/agents/rsl_rl_ppo_cfg.py +++ b/source/SO_100/SO_100/tasks/lift/agents/rsl_rl_ppo_cfg.py @@ -21,7 +21,7 @@ class LiftCubePPORunnerCfg(RslRlOnPolicyRunnerCfg): num_steps_per_env = 24 max_iterations = 1500 save_interval = 50 - experiment_name = "so_arm100_lift" + experiment_name = "lift" empirical_normalization = False policy = RslRlPpoActorCriticCfg( init_noise_std=1.0, diff --git a/source/SO_100/SO_100/tasks/reach/agents/skrl_ppo_cfg.yaml b/source/SO_100/SO_100/tasks/reach/agents/skrl_ppo_cfg.yaml index 73ef462..15e2c15 100644 --- a/source/SO_100/SO_100/tasks/reach/agents/skrl_ppo_cfg.yaml +++ b/source/SO_100/SO_100/tasks/reach/agents/skrl_ppo_cfg.yaml @@ -66,8 +66,8 @@ agent: time_limit_bootstrap: False # logging and checkpoint experiment: - directory: "reach_so_arm100" - experiment_name: "reach_so_arm100" + directory: "reach" + experiment_name: "reach" write_interval: auto checkpoint_interval: auto