|
20 | 20 |
|
21 | 21 | from absl.testing import absltest |
22 | 22 | from absl.testing import parameterized |
| 23 | +import numpy as np |
| 24 | + |
23 | 25 | import mujoco |
24 | 26 | from mujoco import rollout |
25 | | -import numpy as np |
26 | 27 |
|
27 | 28 |
|
28 | 29 | # -------------------------- models used for testing --------------------------- |
@@ -796,6 +797,84 @@ def test_stateless(self): |
796 | 797 | np.testing.assert_array_equal(state, state2) |
797 | 798 | np.testing.assert_array_equal(sensordata, sensordata2) |
798 | 799 |
|
| 800 | + def test_length_one_model_list(self): |
| 801 | + model = mujoco.MjModel.from_xml_string(TEST_XML) |
| 802 | + nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) |
| 803 | + data = mujoco.MjData(model) |
| 804 | + |
| 805 | + initial_state = np.random.randn(nstate) |
| 806 | + control = np.random.randn(3, 3, model.nu) |
| 807 | + |
| 808 | + state, sensordata = rollout.rollout(model, data, initial_state, control) |
| 809 | + state2, sensordata2 = rollout.rollout([model], data, initial_state, control) |
| 810 | + |
| 811 | + # assert that we get same outputs |
| 812 | + np.testing.assert_array_equal(state, state2) |
| 813 | + np.testing.assert_array_equal(sensordata, sensordata2) |
| 814 | + |
| 815 | + def test_data_sizes(self): |
| 816 | + model = mujoco.MjModel.from_xml_string(TEST_XML) |
| 817 | + nstate = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_FULLPHYSICS) |
| 818 | + data = mujoco.MjData(model) |
| 819 | + |
| 820 | + initial_state = np.random.randn(nstate) |
| 821 | + control = np.random.randn(3, 3, model.nu) |
| 822 | + |
| 823 | + # Test passing empty lists for data |
| 824 | + with self.assertRaisesWithLiteralMatch( |
| 825 | + ValueError, 'The list of data instances is empty' |
| 826 | + ): |
| 827 | + rollout.rollout(model, [], initial_state, control) |
| 828 | + |
| 829 | + with self.assertRaisesWithLiteralMatch( |
| 830 | + ValueError, 'The list of data instances is empty' |
| 831 | + ): |
| 832 | + with rollout.Rollout(nthread=0) as rollout_: |
| 833 | + rollout_.rollout(model, [], initial_state, control) |
| 834 | + |
| 835 | + with self.assertRaisesWithLiteralMatch( |
| 836 | + ValueError, 'The list of data instances is empty' |
| 837 | + ): |
| 838 | + with rollout.Rollout(nthread=1) as rollout_: |
| 839 | + rollout_.rollout(model, [], initial_state, control) |
| 840 | + |
| 841 | + with self.assertRaisesWithLiteralMatch( |
| 842 | + ValueError, 'The list of data instances is empty' |
| 843 | + ): |
| 844 | + with rollout.Rollout(nthread=2) as rollout_: |
| 845 | + rollout_.rollout(model, [], initial_state, control) |
| 846 | + |
| 847 | + # Test checking that len(data) equals nthread |
| 848 | + with self.assertRaisesWithLiteralMatch( |
| 849 | + ValueError, |
| 850 | + 'More than one data instance passed but rollout is configured to run on' |
| 851 | + ' main thread', |
| 852 | + ): |
| 853 | + with rollout.Rollout(nthread=0) as rollout_: |
| 854 | + rollout_.rollout( |
| 855 | + model, [copy.copy(data) for i in range(2)], initial_state, control |
| 856 | + ) |
| 857 | + |
| 858 | + with self.assertRaisesWithLiteralMatch( |
| 859 | + ValueError, 'Length of data: 1 not equal to nthread: 2' |
| 860 | + ): |
| 861 | + with rollout.Rollout(nthread=2) as rollout_: |
| 862 | + rollout_.rollout(model, data, initial_state, control) |
| 863 | + |
| 864 | + with self.assertRaisesWithLiteralMatch( |
| 865 | + ValueError, 'Length of data: 1 not equal to nthread: 2' |
| 866 | + ): |
| 867 | + with rollout.Rollout(nthread=2) as rollout_: |
| 868 | + rollout_.rollout(model, [data], initial_state, control) |
| 869 | + |
| 870 | + with self.assertRaisesWithLiteralMatch( |
| 871 | + ValueError, 'Length of data: 3 not equal to nthread: 2' |
| 872 | + ): |
| 873 | + with rollout.Rollout(nthread=2) as rollout_: |
| 874 | + rollout_.rollout( |
| 875 | + model, [copy.copy(data) for i in range(3)], initial_state, control |
| 876 | + ) |
| 877 | + |
799 | 878 |
|
800 | 879 | # -------------- Python implementation of rollout functionality ---------------- |
801 | 880 |
|
|
0 commit comments