Skip to content

Commit ac76540

Browse files
Merge pull request #2377 from aftersomemath:rollout-fix-ndata-check
PiperOrigin-RevId: 718181763 Change-Id: Ie88bf1d26c77478ce3a5ca8fdf29ae90e54b5834
2 parents 67dd482 + cee1221 commit ac76540

File tree

3 files changed

+89
-6
lines changed

3 files changed

+89
-6
lines changed

python/mujoco/rollout.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,16 +254,20 @@ class Rollout {
254254
}
255255

256256
// check length d and nthread are consistent
257-
if (this->nthread_ == 0 && py::len(d) > 1) {
257+
if (py::len(d) == 0) {
258+
std::ostringstream msg;
259+
msg << "The list of data instances is empty";
260+
throw py::value_error(msg.str());
261+
} else if (this->nthread_ == 0 && py::len(d) > 1) {
258262
std::ostringstream msg;
259263
msg << "More than one data instance passed but "
260264
<< "rollout is configured to run on main thread";
261-
py::value_error(msg.str());
262-
} else if (this->nthread_ != py::len(d)) {
265+
throw py::value_error(msg.str());
266+
} else if (this->nthread_ > 0 && this->nthread_ != py::len(d)) {
263267
std::ostringstream msg;
264268
msg << "Length of data: " << py::len(d)
265269
<< " not equal to nthread: " << this->nthread_;
266-
py::value_error(msg.str());
270+
throw py::value_error(msg.str());
267271
}
268272

269273
std::vector<raw::MjData*> data_ptrs(py::len(d));

python/mujoco/rollout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def rollout(
172172
if isinstance(model, list) and nroll == 1:
173173
nroll = len(model)
174174

175-
if isinstance(model, list) and len(model) != nroll:
175+
if isinstance(model, list) and len(model) > 1 and len(model) != nroll:
176176
raise ValueError(
177177
f'nroll inferred as {nroll} but model is length {len(model)}'
178178
)

python/mujoco/rollout_test.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020

2121
from absl.testing import absltest
2222
from absl.testing import parameterized
23+
import numpy as np
24+
2325
import mujoco
2426
from mujoco import rollout
25-
import numpy as np
2627

2728

2829
# -------------------------- models used for testing ---------------------------
@@ -796,6 +797,84 @@ def test_stateless(self):
796797
np.testing.assert_array_equal(state, state2)
797798
np.testing.assert_array_equal(sensordata, sensordata2)
798799

800+
def test_length_one_model_list(self):
801+
model = mujoco.MjModel.from_xml_string(TEST_XML)
802+
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
803+
data = mujoco.MjData(model)
804+
805+
initial_state = np.random.randn(nstate)
806+
control = np.random.randn(3, 3, model.nu)
807+
808+
state, sensordata = rollout.rollout(model, data, initial_state, control)
809+
state2, sensordata2 = rollout.rollout([model], data, initial_state, control)
810+
811+
# assert that we get same outputs
812+
np.testing.assert_array_equal(state, state2)
813+
np.testing.assert_array_equal(sensordata, sensordata2)
814+
815+
def test_data_sizes(self):
816+
model = mujoco.MjModel.from_xml_string(TEST_XML)
817+
nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS)
818+
data = mujoco.MjData(model)
819+
820+
initial_state = np.random.randn(nstate)
821+
control = np.random.randn(3, 3, model.nu)
822+
823+
# Test passing empty lists for data
824+
with self.assertRaisesWithLiteralMatch(
825+
ValueError, 'The list of data instances is empty'
826+
):
827+
rollout.rollout(model, [], initial_state, control)
828+
829+
with self.assertRaisesWithLiteralMatch(
830+
ValueError, 'The list of data instances is empty'
831+
):
832+
with rollout.Rollout(nthread=0) as rollout_:
833+
rollout_.rollout(model, [], initial_state, control)
834+
835+
with self.assertRaisesWithLiteralMatch(
836+
ValueError, 'The list of data instances is empty'
837+
):
838+
with rollout.Rollout(nthread=1) as rollout_:
839+
rollout_.rollout(model, [], initial_state, control)
840+
841+
with self.assertRaisesWithLiteralMatch(
842+
ValueError, 'The list of data instances is empty'
843+
):
844+
with rollout.Rollout(nthread=2) as rollout_:
845+
rollout_.rollout(model, [], initial_state, control)
846+
847+
# Test checking that len(data) equals nthread
848+
with self.assertRaisesWithLiteralMatch(
849+
ValueError,
850+
'More than one data instance passed but rollout is configured to run on'
851+
' main thread',
852+
):
853+
with rollout.Rollout(nthread=0) as rollout_:
854+
rollout_.rollout(
855+
model, [copy.copy(data) for i in range(2)], initial_state, control
856+
)
857+
858+
with self.assertRaisesWithLiteralMatch(
859+
ValueError, 'Length of data: 1 not equal to nthread: 2'
860+
):
861+
with rollout.Rollout(nthread=2) as rollout_:
862+
rollout_.rollout(model, data, initial_state, control)
863+
864+
with self.assertRaisesWithLiteralMatch(
865+
ValueError, 'Length of data: 1 not equal to nthread: 2'
866+
):
867+
with rollout.Rollout(nthread=2) as rollout_:
868+
rollout_.rollout(model, [data], initial_state, control)
869+
870+
with self.assertRaisesWithLiteralMatch(
871+
ValueError, 'Length of data: 3 not equal to nthread: 2'
872+
):
873+
with rollout.Rollout(nthread=2) as rollout_:
874+
rollout_.rollout(
875+
model, [copy.copy(data) for i in range(3)], initial_state, control
876+
)
877+
799878

800879
# -------------- Python implementation of rollout functionality ----------------
801880

0 commit comments

Comments
 (0)