Does this example help?
Thank you, I finally managed to load what I wanted:

```python
import json

import jax
import mediapy as media
import mujoco
from brax.training.agents.ppo import checkpoint
from etils import epath
from ml_collections import config_dict
from mujoco_playground import registry

# Absolute path to the directory holding config.json and the step folders.
path_checkpoints = epath.Path("abs_path_to_checkpoints")

# Load the policy from the specified checkpoint step.
checkpoint_name = "000000000000001"
policy_fn = checkpoint.load_policy(
    path_checkpoints / checkpoint_name,
    deterministic=True,
)

# Read the environment config saved during training, parse the JSON
# into a dict and wrap it in a ConfigDict.
env_cfg_text = (path_checkpoints / "config.json").read_text()
env_cfg = config_dict.ConfigDict(json.loads(env_cfg_text))

# Create the environment using the loaded configuration.
env_name = "registered_env"  # name of the environment the policy was trained on
eval_env = registry.load(env_name, config=env_cfg)

jit_inference_fn = jax.jit(policy_fn)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

# Roll out one episode, stopping early if the episode terminates.
rng = jax.random.PRNGKey(123)
rng, reset_rng = jax.random.split(rng)
state = jit_reset(reset_rng)
rollout = [state]
for _ in range(env_cfg.episode_length):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)
    if state.done:
        break

# Render every `render_every`-th state; this fps plays back at real time.
render_every = 2
fps = 1.0 / eval_env.dt / render_every
print(f"FPS for rendering: {fps}")
traj = rollout[::render_every]

scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False

frames = eval_env.render(
    traj,
    camera="track",  # must be a camera defined in the environment's model
    scene_option=scene_option,
    width=640,
    height=480,
)
media.write_video("test.mp4", frames, fps=fps)
print("Rollout video saved as 'test.mp4'.")
```
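As a worked check of the rendering arithmetic in the snippet above: keeping every `render_every`-th state and playing the frames at `1 / dt / render_every` fps makes the video last about as long as the simulated episode. This is a minimal sketch that assumes a hypothetical control timestep of `dt = 0.02` s and a 1000-step episode with no early termination; neither value comes from the post, and `render_schedule` is a hypothetical helper.

```python
def render_schedule(num_steps, dt, render_every):
    """Return (fps, num_frames, video_seconds) for a subsampled rollout.

    num_steps is the number of env.step calls; the rollout list holds
    num_steps + 1 states because it starts with the reset state.
    """
    num_states = num_steps + 1
    fps = 1.0 / dt / render_every
    # Same count as len(rollout[::render_every]).
    num_frames = len(range(0, num_states, render_every))
    return fps, num_frames, num_frames / fps


# Hypothetical values: dt = 0.02 s (50 Hz control), 1000-step episode.
fps, num_frames, seconds = render_schedule(1000, 0.02, 2)
print(fps, num_frames, seconds)  # prints: 25.0 501 20.04
```

The extra 0.04 s comes from the initial reset state being rendered as a frame. If the episode terminates early, pass the number of steps actually taken.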
Hello MuJoCo Playground users,
I'm trying to figure out the best way to load and test a trained policy in my own code. I've noticed a few different methods in the tutorials, which has made the process a bit unclear.
The `locomotion.ipynb` notebook shows how to load a policy for fine-tuning, and the `training.ipynb` notebook on Brax demonstrates saving and loading using `brax.io.model`.

However, the MuJoCo Playground examples save the model using `checkpoint` from `brax.training.agents.ppo`, and there isn't a clear tutorial on how to simply load one of these checkpoints for evaluation, testing, or rendering.

Could someone please provide a brief tutorial or code example on how to:
Thank you for your time and for this great library!
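In the reply above, the checkpoint step is hard-coded as `"000000000000001"`. When you only want the most recent checkpoint for evaluation, a small helper can pick it. This is a minimal sketch assuming each saved step is a subdirectory named by a zero-padded integer (as in that reply) sitting next to `config.json`; `latest_checkpoint` is a hypothetical helper, not part of Brax or MuJoCo Playground.

```python
from pathlib import Path


def latest_checkpoint(ckpt_dir):
    """Return the step subdirectory of ckpt_dir with the highest step number.

    Assumes (hypothetically) that every saved step is a directory whose name
    is a zero-padded integer such as "000000000000001"; other entries such
    as config.json are skipped.
    """
    ckpt_dir = Path(ckpt_dir)
    steps = [p for p in ckpt_dir.iterdir() if p.is_dir() and p.name.isdecimal()]
    if not steps:
        raise FileNotFoundError(f"no step directories found in {ckpt_dir}")
    # Compare numerically so the padding width does not matter.
    return max(steps, key=lambda p: int(p.name))
```

The result could then replace the hard-coded path, e.g. `checkpoint.load_policy(str(latest_checkpoint(path_checkpoints)), deterministic=True)`. Comparing `int(p.name)` rather than the raw strings keeps the choice correct even if the padding width changes.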