Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions src/user/user_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4049,8 +4049,12 @@ void mjCModel::StoreKeyframes(mjCModel* dest) {
"To prevent this, compile the child model before attaching it again.");
}

// do not change compilation quantities in case the user wants to recompile preserving the state
// rebuild tree lists so that SaveDofOffsets computes correct sizes even when including
// things like `replicate` tags.
if (!compiled) {
ResetTreeLists();
MakeTreeLists();
ProcessLists(/*checkrepeat=*/false);
SaveDofOffsets(/*computesize=*/true);
ComputeReference();
}
Expand All @@ -4067,30 +4071,36 @@ void mjCModel::StoreKeyframes(mjCModel* dest) {
info.mpos = !key->spec_mpos_.empty();
info.mquat = !key->spec_mquat_.empty();
dest->key_pending_.push_back(info);
if (!key->spec_qpos_.empty() && key->spec_qpos_.size() != nq) {
throw mjCError(nullptr, "Keyframe '%s' has invalid qpos size, got %d, should be %d",
key->name.c_str(), key->spec_qpos_.size(), nq);
}
if (!key->spec_qvel_.empty() && key->spec_qvel_.size() != nv) {
throw mjCError(nullptr, "Keyframe %s has invalid qvel size, got %d, should be %d",
key->name.c_str(), key->spec_qvel_.size(), nv);
}
if (!key->spec_act_.empty() && key->spec_act_.size() != na) {
throw mjCError(nullptr, "Keyframe %s has invalid act size, got %d, should be %d",
key->name.c_str(), key->spec_act_.size(), na);
}
if (!key->spec_ctrl_.empty() && key->spec_ctrl_.size() != nu) {
throw mjCError(nullptr, "Keyframe %s has invalid ctrl size, got %d, should be %d",
key->name.c_str(), key->spec_ctrl_.size(), nu);
}
if (!key->spec_mpos_.empty() && key->spec_mpos_.size() != 3*nmocap) {
throw mjCError(nullptr, "Keyframe %s has invalid mpos size, got %d, should be %d",
key->name.c_str(), key->spec_mpos_.size(), 3*nmocap);
}
if (!key->spec_mquat_.empty() && key->spec_mquat_.size() != 4*nmocap) {
throw mjCError(nullptr, "Keyframe %s has invalid mquat size, got %d, should be %d",
key->name.c_str(), key->spec_mquat_.size(), 4*nmocap);

// Only validate if the model has already been compiled, otherwise the spec may not include
// all data.
if (compiled) {
if (!key->spec_qpos_.empty() && key->spec_qpos_.size() != nq) {
throw mjCError(nullptr, "Keyframe '%s' has invalid qpos size, got %d, should be %d",
key->name.c_str(), key->spec_qpos_.size(), nq);
}
if (!key->spec_qvel_.empty() && key->spec_qvel_.size() != nv) {
throw mjCError(nullptr, "Keyframe %s has invalid qvel size, got %d, should be %d",
key->name.c_str(), key->spec_qvel_.size(), nv);
}
if (!key->spec_act_.empty() && key->spec_act_.size() != na) {
throw mjCError(nullptr, "Keyframe %s has invalid act size, got %d, should be %d",
key->name.c_str(), key->spec_act_.size(), na);
}
if (!key->spec_ctrl_.empty() && key->spec_ctrl_.size() != nu) {
throw mjCError(nullptr, "Keyframe %s has invalid ctrl size, got %d, should be %d",
key->name.c_str(), key->spec_ctrl_.size(), nu);
}
if (!key->spec_mpos_.empty() && key->spec_mpos_.size() != 3*nmocap) {
throw mjCError(nullptr, "Keyframe %s has invalid mpos size, got %d, should be %d",
key->name.c_str(), key->spec_mpos_.size(), 3*nmocap);
}
if (!key->spec_mquat_.empty() && key->spec_mquat_.size() != 4*nmocap) {
throw mjCError(nullptr, "Keyframe %s has invalid mquat size, got %d, should be %d",
key->name.c_str(), key->spec_mquat_.size(), 4*nmocap);
}
}

SaveState(info.name, key->spec_qpos_.data(), key->spec_qvel_.data(),
key->spec_act_.data(), key->spec_ctrl_.data(),
key->spec_mpos_.data(), key->spec_mquat_.data());
Expand Down
81 changes: 81 additions & 0 deletions test/user/user_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -975,5 +975,86 @@ TEST_F(MujocoTest, ConvertSpringdamper) {
mj_deleteModel(model);
}

TEST_F(UserModelTest, KeyframeValidationChecks) {
static constexpr char xml[] = R"(
<mujoco>
<worldbody>
<body>
<joint axis="0 1 0"/>
<geom type="box" size="0.1 0.1 0.1"/>
</body>
</worldbody>

<keyframe>
<!-- Including an invalid keyframe -->
<key name="test" qpos="0 1"/>
</keyframe>
</mujoco>
)";

std::array<char, 1024> err;

// model is not compiled, so no errors are expected
mjSpec* spec = mj_parseXMLString(xml, 0, err.data(), err.size());
EXPECT_THAT(spec, NotNull()) << err.data();

// expect failure after compiling the model, validate the message
mjModel* model = mj_compile(spec, 0);
const char* spec_error = mjs_getError(spec);
EXPECT_THAT(model, IsNull());
EXPECT_THAT(spec_error, HasSubstr("expected 1, got 2"));
}

TEST_F(UserModelTest, ReplicateWithKeyframeInIncludedFile) {
static constexpr char robot_xml[] = R"(
<mujoco model="robot">
<worldbody>
<body name="arm">
<joint name="j1" type="hinge" axis="0 0 1"/>
<geom type="box" size="0.1 0.1 0.1"/>
<body name="replicate_body">
<replicate count="2" sep="-">
<site name="s"/>
</replicate>
</body>
</body>
</worldbody>
</mujoco>
)";

static constexpr char scene_xml[] = R"(
<mujoco model="scene">
<include file="robot.xml"/>
<worldbody>
<body name="cube" pos="0 0 1">
<freejoint/>
<inertial pos="0 0 0" mass="1" diaginertia="1 1 1"/>
<geom type="box" size="0.1 0.1 0.1"/>
</body>
</worldbody>
<keyframe>
<key name="test" qpos="0 0 0 1 1 0 0 0"/>
</keyframe>
</mujoco>
)";

std::array<char, 1024> err;

// add both files to VFS to test parsing the include + replicate
mjVFS vfs;
mj_defaultVFS(&vfs);
mj_addBufferVFS(&vfs, "robot.xml", robot_xml, sizeof(robot_xml) - 1);
mj_addBufferVFS(&vfs, "scene.xml", scene_xml, sizeof(scene_xml) - 1);

// load scene, expect success with correct keyframe size
mjModel* m = mj_loadXML("scene.xml", &vfs, err.data(), err.size());
ASSERT_THAT(m, NotNull()) << err.data();
EXPECT_EQ(m->nq, 8);
EXPECT_GE(m->nkey, 1);

mj_deleteModel(m);
mj_deleteVFS(&vfs);
}

} // namespace
} // namespace mujoco