Skip to content

Commit fc13995

Browse files
quaglacopybara-github
authored andcommitted
Yield ownership of vector<double>, vector<float>, and vector<int> to mjSpec in Python bindings.
Fixes #2756 PiperOrigin-RevId: 785417251 Change-Id: Ib399c46ee59585258ace7d9583c68b7d06d0a9e4
1 parent 60f9b34 commit fc13995

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

python/mujoco/codegen/generate_spec_bindings.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def _ptr_binding_code(
193193
[]({rawclassname}& self, std::string_view {varname}) {{
194194
*(self.{fullvarname}) = {varname};
195195
}});"""
196-
elif ( # C++ vectors of values -> Python array
196+
elif ( # C++ vectors of values -> custom array
197197
vartype == 'mjDoubleVec'
198198
or vartype == 'mjFloatVec'
199199
or vartype == 'mjIntVec'
@@ -202,17 +202,17 @@ def _ptr_binding_code(
202202
return f"""\
203203
{classname}.def_property(
204204
"{varname}",
205-
[]({rawclassname}& self) -> py::array_t<{vartype}> {{
206-
return py::array_t<{vartype}>(self.{fullvarname}->size(),
207-
self.{fullvarname}->data());
205+
[]({rawclassname}& self) -> MjTypeVec<{vartype}> {{
206+
return MjTypeVec<{vartype}>(self.{fullvarname}->data(),
207+
self.{fullvarname}->size());
208208
}},
209209
[]({rawclassname}& self, py::object rhs) {{
210210
self.{fullvarname}->clear();
211211
self.{fullvarname}->reserve(py::len(rhs));
212212
for (auto val : rhs) {{
213213
self.{fullvarname}->push_back(py::cast<{vartype}>(val));
214214
}}
215-
}}, py::return_value_policy::reference_internal);"""
215+
}}, py::return_value_policy::move);"""
216216
elif vartype == 'mjByteVec':
217217
return f"""\
218218
{classname}.def_property(

python/mujoco/specs.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ PYBIND11_MODULE(_specs, m) {
221221
DefineArray<char>(m, "MjCharVec");
222222
DefineArray<std::string>(m, "MjStringVec");
223223
DefineArray<std::byte>(m, "MjByteVec");
224+
DefineArray<double>(m, "MjDoubleVec");
225+
DefineArray<float>(m, "MjFloatVec");
226+
DefineArray<int>(m, "MjIntVec");
224227

225228
// ============================= MJSPEC =====================================
226229
mjSpec.def(py::init<>());

python/mujoco/specs_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,13 @@ def test_basic(self):
9898
site = body.add_site()
9999
site.name = 'sitename'
100100
site.type = mujoco.mjtGeom.mjGEOM_BOX
101-
site.userdata = [1, 2, 3, 4, 5, 6]
101+
site.userdata = [7, 2, 3, 4, 5, 6]
102102
self.assertEqual(site.name, 'sitename')
103103
self.assertEqual(site.type, mujoco.mjtGeom.mjGEOM_BOX)
104+
np.testing.assert_array_equal(site.userdata, [7, 2, 3, 4, 5, 6])
105+
106+
# Modify a single element of userdata.
107+
site.userdata[0] = 1
104108
np.testing.assert_array_equal(site.userdata, [1, 2, 3, 4, 5, 6])
105109

106110
# Compile the spec and check for expected values in the model.

0 commit comments

Comments
 (0)