Skip to content

Commit ba183ad

Browse files
quaglacopybara-github
authored andcommitted
Add assets to MjSpec wrapper.
This enables associating assets with a spec object. Also, remove `spec.compile(assets)` and replace it with the `spec.assets` attribute, which must be be set before compile if assets are present. PiperOrigin-RevId: 714981385 Change-Id: Ic3a33c3b75d3a7e14868622aadb57b2a8461a649
1 parent 64d0f57 commit ba183ad

File tree

2 files changed

+56
-30
lines changed

2 files changed

+56
-30
lines changed

python/mujoco/specs.cc

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -71,25 +71,43 @@ using MjDoubleRefVec = Eigen::Ref<const Eigen::VectorXd>;
7171

7272
struct MjSpec {
7373
MjSpec() : ptr(mj_makeSpec()) {}
74-
MjSpec(raw::MjSpec* ptr) : ptr(ptr) {}
74+
MjSpec(raw::MjSpec* ptr,
75+
const std::unordered_map<std::string, py::bytes>& assets_ = {})
76+
: ptr(ptr) {
77+
for (const auto& asset : assets_) {
78+
assets[asset.first.c_str()] = asset.second;
79+
}
80+
}
7581

7682
// copy constructor and assignment
77-
MjSpec(const MjSpec& other) : ptr(mj_copySpec(other.ptr)) {}
83+
MjSpec(const MjSpec& other) : ptr(mj_copySpec(other.ptr)) {
84+
assets = other.assets;
85+
}
7886
MjSpec& operator=(const MjSpec& other) {
7987
ptr = mj_copySpec(other.ptr);
88+
assets = other.assets;
8089
return *this;
8190
}
8291

8392
// move constructor and move assignment
84-
MjSpec(MjSpec&& other) : ptr(other.ptr) { other.ptr = nullptr; }
93+
MjSpec(MjSpec&& other) : ptr(other.ptr) {
94+
other.ptr = nullptr;
95+
assets = other.assets;
96+
other.assets.clear();
97+
}
8598
MjSpec& operator=(MjSpec&& other) {
8699
ptr = other.ptr;
87100
other.ptr = nullptr;
101+
assets = other.assets;
102+
other.assets.clear();
88103
return *this;
89104
}
90105

91-
~MjSpec() { mj_deleteSpec(ptr); }
106+
~MjSpec() {
107+
mj_deleteSpec(ptr);
108+
}
92109
raw::MjSpec* ptr;
110+
py::dict assets;
93111
};
94112

95113
template <typename LoadFunc>
@@ -263,6 +281,9 @@ PYBIND11_MODULE(_specs, m) {
263281
throw py::value_error(error);
264282
}
265283
}
284+
if (assets.has_value()) {
285+
return MjSpec(spec, assets.value());
286+
}
266287
return MjSpec(spec);
267288
},
268289
py::arg("filename"), py::arg("assets") = py::none(), R"mydelimiter(
@@ -305,6 +326,9 @@ PYBIND11_MODULE(_specs, m) {
305326
throw py::value_error(error);
306327
}
307328
}
329+
if (assets.has_value()) {
330+
return MjSpec(spec, assets.value());
331+
}
308332
return MjSpec(spec);
309333
},
310334
py::arg("xml"), py::arg("assets") = py::none(), R"mydelimiter(
@@ -324,7 +348,7 @@ PYBIND11_MODULE(_specs, m) {
324348
m, d);
325349
});
326350
mjSpec.def("copy", [](const MjSpec& self) -> MjSpec {
327-
return MjSpec(mj_copySpec(self.ptr));
351+
return MjSpec(self);
328352
});
329353
mjSpec.def_property_readonly(
330354
"worldbody",
@@ -370,33 +394,33 @@ PYBIND11_MODULE(_specs, m) {
370394
return mjs_findDefault(self.ptr, classname.c_str());
371395
},
372396
py::return_value_policy::reference_internal);
373-
mjSpec.def("compile", [mjmodel_from_spec_ptr](MjSpec& self) {
374-
return mjmodel_from_spec_ptr(reinterpret_cast<uintptr_t>(self.ptr));
397+
mjSpec.def("compile", [mjmodel_from_spec_ptr](MjSpec& self) -> py::object {
398+
if (self.assets.empty()) {
399+
return mjmodel_from_spec_ptr(reinterpret_cast<uintptr_t>(self.ptr));
400+
}
401+
mjVFS vfs;
402+
mj_defaultVFS(&vfs);
403+
for (auto item : self.assets) {
404+
std::string buffer = py::cast<std::string>(item.second);
405+
mj_addBufferVFS(&vfs, py::cast<std::string>(item.first).c_str(),
406+
buffer.c_str(), buffer.size());
407+
};
408+
auto model =
409+
mjmodel_from_spec_ptr(reinterpret_cast<uintptr_t>(self.ptr),
410+
reinterpret_cast<uintptr_t>(&vfs));
411+
mj_deleteVFS(&vfs);
412+
return model;
375413
});
376-
mjSpec.def(
377-
"compile",
378-
[mjmodel_from_spec_ptr](MjSpec& self, py::dict& assets) -> py::object {
379-
mjVFS vfs;
380-
mj_defaultVFS(&vfs);
414+
mjSpec.def_property(
415+
"assets",
416+
[](MjSpec& self) -> py::dict {
417+
return self.assets;
418+
},
419+
[](MjSpec& self, py::dict& assets) {
381420
for (auto item : assets) {
382-
std::string buffer = py::cast<std::string>(item.second);
383-
mj_addBufferVFS(&vfs, py::cast<std::string>(item.first).c_str(),
384-
buffer.c_str(), buffer.size());
421+
self.assets[item.first] = item.second;
385422
};
386-
auto model =
387-
mjmodel_from_spec_ptr(reinterpret_cast<uintptr_t>(self.ptr),
388-
reinterpret_cast<uintptr_t>(&vfs));
389-
mj_deleteVFS(&vfs);
390-
return model;
391-
}, R"mydelimiter(
392-
Compiles the spec and returns the compiled model.
393-
394-
Parameters
395-
----------
396-
assets : dict, optional
397-
A dictionary of assets to be used by the spec. The keys are asset names
398-
and the values are asset contents.
399-
)mydelimiter");
423+
}, py::return_value_policy::reference_internal);
400424
mjSpec.def("to_xml", [](MjSpec& self) -> std::string {
401425
int size = mj_saveXMLString(self.ptr, nullptr, 0, nullptr, 0);
402426
std::unique_ptr<char[]> buf(new char[size + 1]);

python/mujoco/specs_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,8 +714,10 @@ def test_assets(self):
714714
geom = spec.worldbody.add_geom()
715715
geom.type = mujoco.mjtGeom.mjGEOM_MESH
716716
geom.meshname = 'cube'
717-
model = spec.compile({'cube.obj': cube})
717+
spec.assets = {'cube.obj': cube}
718+
model = spec.compile()
718719
self.assertEqual(model.nmeshvert, 8)
720+
self.assertEqual(spec.assets['cube.obj'], cube)
719721

720722
def test_include(self):
721723
included_xml = """

0 commit comments

Comments
 (0)