Skip to content

Commit 03d308d

Browse files
committed
restore test dir except unnecessary files
1 parent 098c887 commit 03d308d

File tree

3 files changed

+138
-0
lines changed

3 files changed

+138
-0
lines changed

test/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
If you are interested in contributing to Menagerie, you will need to install a few Python dependencies which exist for unit testing purposes.
2+
3+
In a virtual environment of your choice with Python 3.7 or later, run:
4+
5+
```bash
6+
pip install -r requirements.txt
7+
pytest test/
8+
```

test/model_test.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2022 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for all models."""
15+
16+
import pathlib
17+
from typing import List
18+
19+
from absl.testing import absltest
20+
from absl.testing import parameterized
21+
import jax
22+
import jax.numpy as jp
23+
import mujoco
24+
from mujoco import mjx
25+
26+
# Internal import.
27+
28+
29+
_ROOT_DIR = pathlib.Path(__file__).parent.parent
30+
_MODEL_DIRS = [f for f in _ROOT_DIR.iterdir() if f.is_dir()]
31+
_MODEL_XMLS: List[pathlib.Path] = []
32+
_MJX_MODEL_XMLS: List[pathlib.Path] = []
33+
34+
35+
def _get_xmls(pattern: str) -> List[pathlib.Path]:
36+
for d in _MODEL_DIRS:
37+
# Produce tuples of test name and XML path.
38+
for f in d.glob(pattern):
39+
test_name = str(f).removeprefix(str(f.parent.parent))
40+
yield (test_name, f)
41+
42+
_MODEL_XMLS = list(_get_xmls('scene*.xml'))
43+
_MJX_MODEL_XMLS = list(_get_xmls('scene*mjx.xml'))
44+
45+
# Total simulation duration, in seconds.
46+
_MAX_SIM_TIME = 0.1
47+
# Scale for the pseudorandom control noise.
48+
_NOISE_SCALE = 1.0
49+
50+
51+
def _pseudorandom_ctrlnoise(
52+
model: mujoco.MjModel,
53+
data: mujoco.MjData,
54+
i: int,
55+
noise: float,
56+
) -> None:
57+
for j in range(model.nu):
58+
ctrlrange = model.actuator_ctrlrange[j]
59+
if model.actuator_ctrllimited[j]:
60+
center = 0.5 * (ctrlrange[1] + ctrlrange[0])
61+
radius = 0.5 * (ctrlrange[1] - ctrlrange[0])
62+
else:
63+
center = 0.0
64+
radius = 1.0
65+
data.ctrl[j] = center + radius * noise * (2*mujoco.mju_Halton(i, j+2) - 1)
66+
67+
68+
class ModelsTest(parameterized.TestCase):
69+
"""Tests that MuJoCo models load and do not emit warnings."""
70+
71+
@parameterized.named_parameters(_MODEL_XMLS)
72+
def test_compiles_and_steps(self, xml_path: pathlib.Path) -> None:
73+
model = mujoco.MjModel.from_xml_path(str(xml_path))
74+
data = mujoco.MjData(model)
75+
i = 0
76+
while data.time < _MAX_SIM_TIME:
77+
_pseudorandom_ctrlnoise(model, data, i, _NOISE_SCALE)
78+
mujoco.mj_step(model, data)
79+
i += 1
80+
# Check no warnings were triggered during the simulation.
81+
if not all(data.warning.number == 0):
82+
warning_info = '\n'.join([
83+
f'{mujoco.mjtWarning(enum_value).name}: count={count}'
84+
for enum_value, count in enumerate(data.warning.number) if count
85+
])
86+
self.fail(f'MuJoCo warning(s) encountered:\n{warning_info}')
87+
88+
89+
class MjxModelsTest(parameterized.TestCase):
90+
"""Tests that MJX models load and do not return NaNs."""
91+
92+
@parameterized.named_parameters(_MJX_MODEL_XMLS)
93+
def test_compiles_and_steps(self, xml_path: pathlib.Path) -> None:
94+
model = mujoco.MjModel.from_xml_path(str(xml_path))
95+
model = mjx.put_model(model)
96+
data = mjx.make_data(model)
97+
ctrlrange = jp.where(
98+
model.actuator_ctrllimited[:, None],
99+
model.actuator_ctrlrange,
100+
jp.array([-10.0, 10.0]),
101+
)
102+
103+
def step(x, _):
104+
data, rng = x
105+
rng, key = jax.random.split(rng)
106+
ctrl = jax.random.uniform(
107+
key,
108+
shape=(model.nu,),
109+
minval=ctrlrange[:, 0],
110+
maxval=ctrlrange[:, 1],
111+
)
112+
data = mjx.step(model, data.replace(ctrl=ctrl))
113+
return (data, rng), ()
114+
115+
(data, _), _ = jax.lax.scan(
116+
step,
117+
(data, jax.random.PRNGKey(0)),
118+
(),
119+
length=min(_MAX_SIM_TIME // model.opt.timestep, 100),
120+
)
121+
122+
self.assertFalse(jp.isnan(data.qpos).any())
123+
124+
125+
if __name__ == '__main__':
126+
absltest.main()

test/requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
mujoco>=3.2.0
2+
mujoco-mjx
3+
absl-py
4+
pytest-xdist

0 commit comments

Comments
 (0)