|
13 | 13 | """Launch Isaac Sim Simulator first.""" |
14 | 14 |
|
15 | 15 | import argparse |
| 16 | +import sys |
16 | 17 |
|
17 | 18 | from isaaclab.app import AppLauncher |
18 | 19 |
|
|
28 | 29 | ) |
29 | 30 | parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") |
30 | 31 | parser.add_argument("--task", type=str, default=None, help="Name of the task.") |
| 32 | +parser.add_argument( |
| 33 | + "--agent", type=str, default="rsl_rl_cfg_entry_point", help="Name of the RL agent configuration entry point." |
| 34 | +) |
| 35 | +parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") |
31 | 36 | parser.add_argument( |
32 | 37 | "--use_pretrained_checkpoint", |
33 | 38 | action="store_true", |
|
38 | 43 | cli_args.add_rsl_rl_args(parser) |
39 | 44 | # append AppLauncher cli args |
40 | 45 | AppLauncher.add_app_launcher_args(parser) |
41 | | -args_cli = parser.parse_args() |
| 46 | +# parse the known arguments; unrecognized ones are collected separately for Hydra |
| 47 | +args_cli, hydra_args = parser.parse_known_args() |
42 | 48 | # always enable cameras to record video |
43 | 49 | if args_cli.video: |
44 | 50 | args_cli.enable_cameras = True |
45 | 51 |
|
| 52 | +# rebuild sys.argv as the script name plus the arguments argparse did not recognize, so Hydra only sees those |
| 53 | +sys.argv = [sys.argv[0]] + hydra_args |
| 54 | + |
46 | 55 | # launch omniverse app |
47 | 56 | app_launcher = AppLauncher(args_cli) |
48 | 57 | simulation_app = app_launcher.app |
|
56 | 65 | import isaaclab_tasks # noqa: F401 |
57 | 66 | import SO_100.tasks # noqa: F401 |
58 | 67 | import torch |
59 | | -from isaaclab.envs import DirectMARLEnv, multi_agent_to_single_agent |
| 68 | +from isaaclab.envs import ( |
| 69 | + DirectMARLEnv, |
| 70 | + DirectMARLEnvCfg, |
| 71 | + DirectRLEnvCfg, |
| 72 | + ManagerBasedRLEnvCfg, |
| 73 | + multi_agent_to_single_agent, |
| 74 | +) |
60 | 75 | from isaaclab.utils.assets import retrieve_file_path |
61 | 76 | from isaaclab.utils.dict import print_dict |
62 | 77 | from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint |
63 | 78 | from isaaclab_rl.rsl_rl import ( |
64 | | - RslRlOnPolicyRunnerCfg, |
| 79 | + RslRlBaseRunnerCfg, |
65 | 80 | RslRlVecEnvWrapper, |
66 | 81 | export_policy_as_jit, |
67 | 82 | export_policy_as_onnx, |
68 | 83 | ) |
69 | | -from isaaclab_tasks.utils import get_checkpoint_path, parse_env_cfg |
70 | | -from rsl_rl.runners import OnPolicyRunner |
| 84 | +from isaaclab_tasks.utils import get_checkpoint_path |
| 85 | +from isaaclab_tasks.utils.hydra import hydra_task_config |
| 86 | +from rsl_rl.runners import DistillationRunner, OnPolicyRunner |
71 | 87 |
|
72 | 88 |
|
73 | | -def main(): |
| 89 | +@hydra_task_config(args_cli.task, args_cli.agent) |
| 90 | +def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg): |
74 | 91 | """Play with RSL-RL agent.""" |
75 | | - # parse configuration |
76 | | - env_cfg = parse_env_cfg( |
77 | | - args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric |
78 | | - ) |
79 | | - agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli) |
| 92 | + # grab task name for checkpoint path |
| 93 | + task_name = args_cli.task.split(":")[-1] |
| 94 | + train_task_name = task_name.replace("-Play", "") |
| 95 | + |
| 96 | + # override configurations with non-hydra CLI arguments |
| 97 | + agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli) |
| 98 | + env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs |
| 99 | + |
| 100 | + # set the environment seed |
| 101 | + # note: certain randomizations occur in the environment initialization so we set the seed here |
| 102 | + env_cfg.seed = agent_cfg.seed |
| 103 | + env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device |
80 | 104 |
|
81 | 105 | # specify directory for logging experiments |
82 | 106 | log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) |
83 | 107 | log_root_path = os.path.abspath(log_root_path) |
84 | 108 | print(f"[INFO] Loading experiment from directory: {log_root_path}") |
85 | 109 | if args_cli.use_pretrained_checkpoint: |
86 | | - resume_path = get_published_pretrained_checkpoint("rsl_rl", args_cli.task) |
| 110 | + resume_path = get_published_pretrained_checkpoint("rsl_rl", train_task_name) |
87 | 111 | if not resume_path: |
88 | 112 | print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.") |
89 | 113 | return |
@@ -118,32 +142,43 @@ def main(): |
118 | 142 |
|
119 | 143 | print(f"[INFO]: Loading model checkpoint from: {resume_path}") |
120 | 144 | # load previously trained model |
121 | | - ppo_runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device) |
122 | | - ppo_runner.load(resume_path) |
| 145 | + if agent_cfg.class_name == "OnPolicyRunner": |
| 146 | + runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device) |
| 147 | + elif agent_cfg.class_name == "DistillationRunner": |
| 148 | + runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device) |
| 149 | + else: |
| 150 | + raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}") |
| 151 | + runner.load(resume_path) |
123 | 152 |
|
124 | 153 | # obtain the trained policy for inference |
125 | | - policy = ppo_runner.get_inference_policy(device=env.unwrapped.device) |
| 154 | + policy = runner.get_inference_policy(device=env.unwrapped.device) |
126 | 155 |
|
127 | 156 | # extract the neural network module |
128 | 157 | # we do this in a try-except to maintain backwards compatibility. |
129 | 158 | try: |
130 | 159 | # version 2.3 onwards |
131 | | - policy_nn = ppo_runner.alg.policy |
| 160 | + policy_nn = runner.alg.policy |
132 | 161 | except AttributeError: |
133 | 162 | # version 2.2 and below |
134 | | - policy_nn = ppo_runner.alg.actor_critic |
| 163 | + policy_nn = runner.alg.actor_critic |
| 164 | + |
| 165 | + # extract the normalizer |
| 166 | + if hasattr(policy_nn, "actor_obs_normalizer"): |
| 167 | + normalizer = policy_nn.actor_obs_normalizer |
| 168 | + elif hasattr(policy_nn, "student_obs_normalizer"): |
| 169 | + normalizer = policy_nn.student_obs_normalizer |
| 170 | + else: |
| 171 | + normalizer = None |
135 | 172 |
|
136 | 173 | # export policy to onnx/jit |
137 | 174 | export_model_dir = os.path.join(os.path.dirname(resume_path), "exported") |
138 | | - export_policy_as_jit(policy_nn, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt") |
139 | | - export_policy_as_onnx( |
140 | | - policy_nn, normalizer=ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx" |
141 | | - ) |
| 175 | + export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt") |
| 176 | + export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx") |
142 | 177 |
|
143 | 178 | dt = env.unwrapped.step_dt |
144 | 179 |
|
145 | 180 | # reset environment |
146 | | - obs, _ = env.get_observations() |
| 181 | + obs = env.get_observations() |
147 | 182 | timestep = 0 |
148 | 183 | # simulate environment |
149 | 184 | while simulation_app.is_running(): |
|
0 commit comments