Skip to content

Commit f2fcd05

Browse files
thowellcopybara-github
authored andcommitted
Import google-deepmind/mujoco_warp from GitHub.
PiperOrigin-RevId: 807766737 Change-Id: I1c22361ffa61f27fd8e0fd534b96488531761916
1 parent 52da758 commit f2fcd05

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+17956
-3868
lines changed

mjx/cuda_requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ jax-cuda12-pjrt==0.5.3; python_version >= '3.10' \
1616
jax-cuda12-pjrt==0.4.30; python_version == '3.9' \
1717
--hash=sha256:895d0198ad99638fcaf976c47592e2a543eef79ea15fabd24a402d055390c328 \
1818
--hash=sha256:c36fb1e0c236563bf3a87e70f4d1ab28a31d7cf5d722c9ede30c4172116e8bcb
19-
warp-lang==1.8.1 \
20-
--hash=sha256:cfc59e1070ad71531b5d83186de48162507277af344a102fa33d5df9cdb942f7 \
21-
--hash=sha256:1db9ca92c46902b76bb99565c544347d1a32e9fb875ce902f1cafb94978d1ac3
19+
warp-lang==1.9.0 \
20+
--hash=sha256:23165d3291eeecc5ac47b9a3de0b93d34b4c921414e5a761ab96d68861efa3cb \
21+
--hash=sha256:7ea8057e5d6fdb9b0885d725b51954c5fc6dac62d6ac007828148b1cb13ee615

mjx/mujoco/mjx/_src/io_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
"""Tests for io functions."""
1616

1717
import os
18+
import tempfile
1819
from unittest import mock
20+
1921
from absl.testing import absltest
2022
from absl.testing import parameterized
2123
import jax
@@ -31,6 +33,7 @@
3133
# pylint: enable=g-importing-member
3234
import mujoco.mjx.warp as mjxw
3335
from mujoco.mjx.warp import types as mjxw_types
36+
from mujoco.mjx.warp import warp as wp # pylint: disable=g-importing-member
3437
import numpy as np
3538

3639

@@ -125,6 +128,12 @@ def _get_name_from_path(path: jax.tree_util.KeyPath) -> str:
125128
class ModelIOTest(parameterized.TestCase):
126129
"""IO tests for mjx.Model."""
127130

131+
def setUp(self):
132+
super().setUp()
133+
if mjxw.WARP_INSTALLED:
134+
self.tempdir = tempfile.TemporaryDirectory()
135+
wp.config.kernel_cache_dir = self.tempdir.name
136+
128137
@parameterized.product(
129138
xml=(_MULTIPLE_CONVEX_OBJECTS, _MULTIPLE_CONSTRAINTS),
130139
impl=('jax', 'c', 'warp'),
@@ -326,6 +335,12 @@ def check_ndim(path, x):
326335
class DataIOTest(parameterized.TestCase):
327336
"""IO tests for mjx.Data."""
328337

338+
def setUp(self):
339+
super().setUp()
340+
if mjxw.WARP_INSTALLED:
341+
self.tempdir = tempfile.TemporaryDirectory()
342+
wp.config.kernel_cache_dir = self.tempdir.name
343+
329344
@parameterized.parameters('jax', 'c')
330345
def test_make_data(self, impl: str):
331346
"""Test that make_data returns the correct shapes."""

mjx/mujoco/mjx/third_party/mujoco_warp/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from mujoco.mjx.third_party.mujoco_warp._src.types import Data as Data
2222
# isort: on
2323

24+
from ._src import test_util as test_util # used by viewer and testspeed, not meant for public consumption
2425
from mujoco.mjx.third_party.mujoco_warp._src.collision_driver import collision as collision
2526
from mujoco.mjx.third_party.mujoco_warp._src.collision_driver import nxn_broadphase as nxn_broadphase
2627
from mujoco.mjx.third_party.mujoco_warp._src.collision_driver import sap_broadphase as sap_broadphase
@@ -43,6 +44,7 @@
4344
from mujoco.mjx.third_party.mujoco_warp._src.io import make_data as make_data
4445
from mujoco.mjx.third_party.mujoco_warp._src.io import put_data as put_data
4546
from mujoco.mjx.third_party.mujoco_warp._src.io import put_model as put_model
47+
from mujoco.mjx.third_party.mujoco_warp._src.io import reset_data as reset_data
4648
from mujoco.mjx.third_party.mujoco_warp._src.passive import passive as passive
4749
from mujoco.mjx.third_party.mujoco_warp._src.ray import ray as ray
4850
from mujoco.mjx.third_party.mujoco_warp._src.sensor import energy_pos as energy_pos
@@ -66,8 +68,6 @@
6668
from mujoco.mjx.third_party.mujoco_warp._src.support import contact_force as contact_force
6769
from mujoco.mjx.third_party.mujoco_warp._src.support import mul_m as mul_m
6870
from mujoco.mjx.third_party.mujoco_warp._src.support import xfrc_accumulate as xfrc_accumulate
69-
from mujoco.mjx.third_party.mujoco_warp._src.test_util import BenchmarkSuite as BenchmarkSuite
70-
from mujoco.mjx.third_party.mujoco_warp._src.test_util import benchmark as benchmark
7171
from mujoco.mjx.third_party.mujoco_warp._src.types import BroadphaseFilter as BroadphaseFilter
7272
from mujoco.mjx.third_party.mujoco_warp._src.types import BroadphaseType as BroadphaseType
7373
from mujoco.mjx.third_party.mujoco_warp._src.types import ConeType as ConeType

mjx/mujoco/mjx/third_party/mujoco_warp/_src/broadphase_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,6 @@ def test_broadphase_filter(self):
326326
broadphase_caller(m, d)
327327
self.assertEqual(d.ncollision.numpy()[0], 0)
328328

329-
# TODO(team): test margin
330-
# TODO(team): test DisableBit.FILTERPARENT
331-
332329

333330
if __name__ == "__main__":
334331
wp.init()

0 commit comments

Comments
 (0)