diff --git a/python/mujoco_mpc/mjx/README.md b/python/mujoco_mpc/mjx/README.md new file mode 100644 index 000000000..b4c82a32b --- /dev/null +++ b/python/mujoco_mpc/mjx/README.md @@ -0,0 +1,10 @@ +# MJX Predictive Sampling + +Run `handover` example: + +```sh +python visualize.py +``` + +## +Requires: mujoco, mujoco-mjx, jax[cuda], matplotlib, mediapy (Python), ffmpeg \ No newline at end of file diff --git a/python/mujoco_mpc/mjx/tasks/bimanual/handover.py b/python/mujoco_mpc/mjx/tasks/bimanual/handover.py index 4d1d81205..31efb3f97 100644 --- a/python/mujoco_mpc/mjx/tasks/bimanual/handover.py +++ b/python/mujoco_mpc/mjx/tasks/bimanual/handover.py @@ -13,13 +13,14 @@ # limitations under the License. # ============================================================================== -from etils import epath +from typing import Callable +from pathlib import Path import jax from jax import numpy as jp import mujoco from mujoco import mjx -from mujoco_mpc.mjx import predictive_sampling +CostFn = Callable[[mjx.Model, mjx.Data], jax.Array] def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array: """Returns cost for bimanual bring to target task.""" @@ -48,22 +49,14 @@ def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array: def get_models_and_cost_fn() -> ( - tuple[mujoco.MjModel, mujoco.MjModel, predictive_sampling.CostFn] + tuple[mujoco.MjModel, mujoco.MjModel, CostFn] ): """Returns a tuple of the model and the cost function.""" - path = epath.Path( - 'build/mjpc/tasks/bimanual/' + model_path = ( + Path(__file__).parent.parent.parent + / "../../../build/mjpc/tasks/bimanual/mjx_scene.xml" ) - model_file_name = 'mjx_scene.xml' - xml = (path / model_file_name).read_text() - assets = {} - for f in path.glob('*.xml'): - if f.name == model_file_name: - continue - assets[f.name] = f.read_bytes() - for f in (path / 'assets').glob('*'): - assets[f.name] = f.read_bytes() - sim_model = mujoco.MjModel.from_xml_string(xml, assets) - plan_model = mujoco.MjModel.from_xml_string(xml, assets) + sim_model = mujoco.MjModel.from_xml_path(str(model_path)) + plan_model = mujoco.MjModel.from_xml_path(str(model_path)) plan_model.opt.timestep = 0.01 # incidentally, already the case return sim_model, plan_model, bring_to_target diff --git a/python/mujoco_mpc/mjx/visualize.py b/python/mujoco_mpc/mjx/visualize.py index deb493d65..b3b21af87 100644 --- a/python/mujoco_mpc/mjx/visualize.py +++ b/python/mujoco_mpc/mjx/visualize.py @@ -16,8 +16,8 @@ import matplotlib.pyplot as plt import mediapy import mujoco -from mujoco_mpc.mjx import predictive_sampling -from mujoco_mpc.mjx.tasks.bimanual import handover +import predictive_sampling +from tasks.bimanual import handover # %% nsteps = 500 steps_per_plan = 4