Skip to content

Commit 924ee30

Browse files
havesscopybara-github
authored andcommitted
Modify user values in mujoco to allow users to provide cleanup functions to avoid memory leaks.
For the MJCF -> USD plugin, and I suspect other usecases for user values, it's necessary to give ownership of the object to Mujoco. However since they get type erased, we can't clean up the data automatically for the user. Instead this allows c++ clients to provide a cleanup function with their data. PiperOrigin-RevId: 756832010 Change-Id: I80b8e7822e1a0e399a0d19dcaa57ee69b3ccc16a
1 parent baf8426 commit 924ee30

File tree

9 files changed

+91
-14
lines changed

9 files changed

+91
-14
lines changed

doc/APIreference/functions.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4459,6 +4459,17 @@ Transform body into a frame.
44594459

44604460
Set user payload, overriding the existing value for the specified key if present.
44614461

4462+
.. _mjs_setUserValueWithCleanup:
4463+
4464+
`mjs_setUserValueWithCleanup <#mjs_setUserValueWithCleanup>`__
4465+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4466+
4467+
.. mujoco-include:: mjs_setUserValueWithCleanup
4468+
4469+
Set user payload, overriding the existing value for the specified key if
4470+
present. This version differs from mjs_setUserValue in that it takes a
4471+
cleanup function that will be called when the user payload is deleted.
4472+
44624473
.. _mjs_getUserValue:
44634474

44644475
`mjs_getUserValue <#mjs_getUserValue>`__

doc/includes/references.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3732,6 +3732,9 @@ const char* mjs_resolveOrientation(double quat[4], mjtByte degree, const char* s
37323732
const mjsOrientation* orientation);
37333733
mjsFrame* mjs_bodyToFrame(mjsBody** body);
37343734
void mjs_setUserValue(mjsElement* element, const char* key, const void* data);
3735+
void mjs_setUserValueWithCleanup(mjsElement* element, const char* key,
3736+
const void* data,
3737+
void (*cleanup)(const void*));
37353738
const void* mjs_getUserValue(mjsElement* element, const char* key);
37363739
void mjs_deleteUserValue(mjsElement* element, const char* key);
37373740
void mjs_defaultSpec(mjSpec* spec);

include/mujoco/mujoco.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,13 @@ MJAPI mjsFrame* mjs_bodyToFrame(mjsBody** body);
16441644
// Set user payload, overriding the existing value for the specified key if present.
16451645
MJAPI void mjs_setUserValue(mjsElement* element, const char* key, const void* data);
16461646

1647+
// Set user payload, overriding the existing value for the specified key if
1648+
// present. This version differs from mjs_setUserValue in that it takes a
1649+
// cleanup function that will be called when the user payload is deleted.
1650+
MJAPI void mjs_setUserValueWithCleanup(mjsElement* element, const char* key,
1651+
const void* data,
1652+
void (*cleanup)(const void*));
1653+
16471654
// Return user payload or NULL if none found.
16481655
MJAPI const void* mjs_getUserValue(mjsElement* element, const char* key);
16491656

python/mujoco/introspect/codegen/generate_functions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,13 @@ def _make_comment(self, node: ClangJsonNode) -> str:
9999
return ''.join(strings)
100100

101101
def visit(self, node: ClangJsonNode) -> None:
102-
if (node.get('kind') == 'FunctionDecl' and
103-
node.get('name', '').startswith('mj')):
102+
# Skip mjs_setUserValueWithCleanup as it's only useful for heap allocated
103+
# objects and doesn't need a python wrapper.
104+
if (
105+
node.get('kind') == 'FunctionDecl'
106+
and node.get('name', '').startswith('mj')
107+
and node.get('name', '') != 'mjs_setUserValueWithCleanup'
108+
):
104109
func_decl = self._make_function(node)
105110
self._functions[func_decl.name] = func_decl
106111

src/user/user_api.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -860,16 +860,18 @@ mjsFrame* mjs_bodyToFrame(mjsBody** body) {
860860
return &frameC->spec;
861861
}
862862

863-
863+
void mjs_setUserValue(mjsElement* element, const char* key, const void* data) {
864+
mjs_setUserValueWithCleanup(element, key, data, nullptr);
865+
}
864866

865867
// set user payload
866-
void mjs_setUserValue(mjsElement* element, const char* key, const void* data) {
868+
void mjs_setUserValueWithCleanup(mjsElement* element, const char* key,
869+
const void* data,
870+
void (*cleanup)(const void*)) {
867871
mjCBase* baseC = static_cast<mjCBase*>(element);
868-
baseC->SetUserValue(key, data);
872+
baseC->SetUserValue(key, data, cleanup);
869873
}
870874

871-
872-
873875
// return user payload or NULL if none found
874876
const void* mjs_getUserValue(mjsElement* element, const char* key) {
875877
mjCBase* baseC = static_cast<mjCBase*>(element);

src/user/user_api.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,11 @@ MJAPI mjsFrame* mjs_bodyToFrame(mjsBody** body);
374374
// Set user payload.
375375
MJAPI void mjs_setUserValue(mjsElement* element, const char* key, const void* data);
376376

377+
// Set user payload.
378+
MJAPI void mjs_setUserValueWithCleanup(mjsElement* element, const char* key,
379+
const void* data,
380+
void (*cleanup)(const void*));
381+
377382
// Return user payload or NULL if none found.
378383
MJAPI const void* mjs_getUserValue(mjsElement* element, const char* key);
379384

src/user/user_objects.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -784,15 +784,14 @@ void mjCBase::SetFrame(mjCFrame* _frame) {
784784
frame = _frame;
785785
}
786786

787-
788-
void mjCBase::SetUserValue(std::string_view key, const void* data) {
789-
user_payload_[std::string(key)] = data;
787+
void mjCBase::SetUserValue(std::string_view key, const void* data,
788+
void (*cleanup)(const void*)) {
789+
user_payload_[std::string(key)] = UserValue(data, cleanup);
790790
}
791791

792-
793792
const void* mjCBase::GetUserValue(std::string_view key) {
794793
auto found = user_payload_.find(std::string(key));
795-
return found != user_payload_.end() ? found->second : nullptr;
794+
return found != user_payload_.end() ? found->second.value : nullptr;
796795
}
797796

798797

src/user/user_objects.h

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,8 @@ class mjCBase : public mjCBase_ {
281281
}
282282

283283
// Set and get user payload
284-
void SetUserValue(std::string_view key, const void* data);
284+
void SetUserValue(std::string_view key, const void* data,
285+
void (*cleanup)(const void*));
285286
const void* GetUserValue(std::string_view key);
286287
void DeleteUserValue(std::string_view key);
287288

@@ -292,8 +293,44 @@ class mjCBase : public mjCBase_ {
292293
// reference count for allowing deleting an attached object
293294
int refcount = 1;
294295

296+
// Arbitrary user value that cleans up the data when destroyed.
297+
struct UserValue {
298+
const void* value = nullptr;
299+
void (*cleanup)(const void*) = nullptr;
300+
301+
UserValue() {}
302+
UserValue(const void* value, void (*cleanup)(const void*))
303+
: value(value), cleanup(cleanup) {}
304+
UserValue(const UserValue& other) = delete;
305+
UserValue& operator=(const UserValue& other) = delete;
306+
307+
UserValue(UserValue&& other) : value(other.value), cleanup(other.cleanup) {
308+
other.value = nullptr;
309+
other.cleanup = nullptr;
310+
}
311+
312+
UserValue& operator=(UserValue&& other) {
313+
if (this != &other) {
314+
if (cleanup && value) {
315+
cleanup(value);
316+
}
317+
value = other.value;
318+
cleanup = other.cleanup;
319+
other.value = nullptr;
320+
other.cleanup = nullptr;
321+
}
322+
return *this;
323+
}
324+
325+
~UserValue() {
326+
if (cleanup && value) {
327+
cleanup(value);
328+
}
329+
}
330+
};
331+
295332
// user payload
296-
std::unordered_map<std::string, const void*> user_payload_;
333+
std::unordered_map<std::string, UserValue> user_payload_;
297334
};
298335

299336

test/user/user_api_test.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2845,6 +2845,14 @@ TEST_F(MujocoTest, UserValue) {
28452845
EXPECT_STREQ(static_cast<const char*>(payload), data.c_str());
28462846
mjs_deleteUserValue(body->element, "key");
28472847
EXPECT_THAT(mjs_getUserValue(body->element, "key"), IsNull());
2848+
2849+
std::string* heap_data = new std::string("heap_data");
2850+
mjs_setUserValueWithCleanup(
2851+
body->element, "key", heap_data,
2852+
[](const void* data) { delete static_cast<const std::string*>(data); });
2853+
payload = mjs_getUserValue(body->element, "key");
2854+
EXPECT_STREQ(static_cast<const std::string*>(payload)->c_str(),
2855+
heap_data->c_str());
28482856
mj_deleteSpec(spec);
28492857
}
28502858

0 commit comments

Comments
 (0)