Skip to content

Commit cee1221

Browse files
committed
rollout: fix bugs in checking list lengths, add tests
1 parent 699a676 commit cee1221

File tree

3 files changed

+90
-6
lines changed

3 files changed

+90
-6
lines changed

python/mujoco/rollout.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,16 +254,20 @@ class Rollout {
254254
}
255255

256256
// check length d and nthread are consistent
257-
if (this->nthread_ == 0 && py::len(d) > 1) {
257+
if (py::len(d) == 0) {
258+
std::ostringstream msg;
259+
msg << "The list of data instances is empty";
260+
throw py::value_error(msg.str());
261+
} else if (this->nthread_ == 0 && py::len(d) > 1) {
258262
std::ostringstream msg;
259263
msg << "More than one data instance passed but "
260264
<< "rollout is configured to run on main thread";
261-
py::value_error(msg.str());
262-
} else if (this->nthread_ != py::len(d)) {
265+
throw py::value_error(msg.str());
266+
} else if (this->nthread_ > 0 && this->nthread_ != py::len(d)) {
263267
std::ostringstream msg;
264268
msg << "Length of data: " << py::len(d)
265269
<< " not equal to nthread: " << this->nthread_;
266-
py::value_error(msg.str());
270+
throw py::value_error(msg.str());
267271
}
268272

269273
std::vector<raw::MjData*> data_ptrs(py::len(d));

python/mujoco/rollout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def rollout(
172172
if isinstance(model, list) and nroll == 1:
173173
nroll = len(model)
174174

175-
if isinstance(model, list) and len(model) != nroll:
175+
if isinstance(model, list) and len(model) > 1 and len(model) != nroll:
176176
raise ValueError(
177177
f'nroll inferred as {nroll} but model is length {len(model)}'
178178
)

python/mujoco/rollout_test.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
"""tests for rollout function."""
1616

1717
import concurrent.futures
18+
import copy
1819
import threading
1920

2021
from absl.testing import absltest
2122
from absl.testing import parameterized
23+
import numpy as np
24+
2225
import mujoco
2326
from mujoco import rollout
24-
import numpy as np
2527

2628
# -------------------------- models used for testing ---------------------------
2729

@@ -794,6 +796,84 @@ def test_stateless(self):
794796
np.testing.assert_array_equal(state, state2)
795797
np.testing.assert_array_equal(sensordata, sensordata2)
796798

799+
def test_length_one_model_list(self):
800+
model = mujoco.MjModel.from_xml_string(TEST_XML)
801+
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
802+
data = mujoco.MjData(model)
803+
804+
initial_state = np.random.randn(nstate)
805+
control = np.random.randn(3, 3, model.nu)
806+
807+
state, sensordata = rollout.rollout(model, data, initial_state, control)
808+
state2, sensordata2 = rollout.rollout([model], data, initial_state, control)
809+
810+
# assert that we get same outputs
811+
np.testing.assert_array_equal(state, state2)
812+
np.testing.assert_array_equal(sensordata, sensordata2)
813+
814+
def test_data_sizes(self):
815+
model = mujoco.MjModel.from_xml_string(TEST_XML)
816+
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
817+
data = mujoco.MjData(model)
818+
819+
initial_state = np.random.randn(nstate)
820+
control = np.random.randn(3, 3, model.nu)
821+
822+
# Test passing empty lists for data
823+
with self.assertRaisesWithLiteralMatch(
824+
ValueError, 'The list of data instances is empty'
825+
):
826+
rollout.rollout(model, [], initial_state, control)
827+
828+
with self.assertRaisesWithLiteralMatch(
829+
ValueError, 'The list of data instances is empty'
830+
):
831+
with rollout.Rollout(nthread=0) as rollout_:
832+
rollout_.rollout(model, [], initial_state, control)
833+
834+
with self.assertRaisesWithLiteralMatch(
835+
ValueError, 'The list of data instances is empty'
836+
):
837+
with rollout.Rollout(nthread=1) as rollout_:
838+
rollout_.rollout(model, [], initial_state, control)
839+
840+
with self.assertRaisesWithLiteralMatch(
841+
ValueError, 'The list of data instances is empty'
842+
):
843+
with rollout.Rollout(nthread=2) as rollout_:
844+
rollout_.rollout(model, [], initial_state, control)
845+
846+
# Test checking that len(data) equals nthread
847+
with self.assertRaisesWithLiteralMatch(
848+
ValueError,
849+
'More than one data instance passed but rollout is configured to run on'
850+
' main thread',
851+
):
852+
with rollout.Rollout(nthread=0) as rollout_:
853+
rollout_.rollout(
854+
model, [copy.copy(data) for i in range(2)], initial_state, control
855+
)
856+
857+
with self.assertRaisesWithLiteralMatch(
858+
ValueError, 'Length of data: 1 not equal to nthread: 2'
859+
):
860+
with rollout.Rollout(nthread=2) as rollout_:
861+
rollout_.rollout(model, data, initial_state, control)
862+
863+
with self.assertRaisesWithLiteralMatch(
864+
ValueError, 'Length of data: 1 not equal to nthread: 2'
865+
):
866+
with rollout.Rollout(nthread=2) as rollout_:
867+
rollout_.rollout(model, [data], initial_state, control)
868+
869+
with self.assertRaisesWithLiteralMatch(
870+
ValueError, 'Length of data: 3 not equal to nthread: 2'
871+
):
872+
with rollout.Rollout(nthread=2) as rollout_:
873+
rollout_.rollout(
874+
model, [copy.copy(data) for i in range(3)], initial_state, control
875+
)
876+
797877

798878
# -------------- Python implementation of rollout functionality ----------------
799879

0 commit comments

Comments
 (0)