|
| 1 | +// Copyright 2025 DeepMind Technologies Limited |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "experimental/toolbox/helpers.h" |
| 16 | + |
| 17 | +#include <cstddef> |
| 18 | +#include <cstdint> |
| 19 | +#include <cstdio> |
| 20 | +#include <cstring> |
| 21 | +#include <fstream> |
| 22 | +#include <ios> |
| 23 | +#include <iterator> |
| 24 | +#include <string> |
| 25 | +#include <vector> |
| 26 | + |
| 27 | +#include "third_party/libwebp/src/webp/encode.h" |
| 28 | +#include "third_party/libwebp/src/webp/types.h" |
| 29 | +#include <mujoco/mjrender.h> |
| 30 | +#include <mujoco/mjxmacro.h> |
| 31 | +#include <mujoco/mujoco.h> |
| 32 | +#include "xml/xml_api.h" |
| 33 | + |
| 34 | +namespace mujoco::toolbox { |
| 35 | + |
| 36 | +mjModel* LoadMujocoModel(const std::string& model_file, const mjVFS* vfs) { |
| 37 | + mjModel* model = nullptr; |
| 38 | + |
| 39 | + if (model_file.empty()) { |
| 40 | + auto spec = mj_makeSpec(); |
| 41 | + model = mj_compile(spec, 0); |
| 42 | + mj_deleteSpec(spec); |
| 43 | + } else if (model_file.ends_with(".mjb")) { |
| 44 | + model = mj_loadModel(model_file.c_str(), 0); |
| 45 | + if (!model) { |
| 46 | + mju_error("Could not load binary model"); |
| 47 | + } |
| 48 | + } else if (model_file.ends_with(".xml")) { |
| 49 | + char error[1000] = ""; |
| 50 | + model = mj_loadXML(model_file.c_str(), vfs, error, sizeof(error)); |
| 51 | + if (!model) { |
| 52 | + mju_error("Load model error: %s", error); |
| 53 | + } |
| 54 | + } else { |
| 55 | + char error[1000] = ""; |
| 56 | + auto spec = |
| 57 | + mj_parseXMLString(model_file.c_str(), nullptr, error, sizeof(error)); |
| 58 | + if (!spec) { |
| 59 | + mju_error("Load model error: %s", error); |
| 60 | + } |
| 61 | + model = mj_compile(spec, 0); |
| 62 | + mj_deleteSpec(spec); |
| 63 | + } |
| 64 | + return model; |
| 65 | +} |
| 66 | + |
| 67 | +void SaveText(const std::string& contents, const std::string& filename) { |
| 68 | + std::ofstream file(filename); |
| 69 | + file.write(contents.data(), contents.size()); |
| 70 | + file.close(); |
| 71 | +} |
| 72 | + |
| 73 | +std::string LoadText(const std::string& filename) { |
| 74 | + std::ifstream file(filename); |
| 75 | + std::string contents((std::istreambuf_iterator<char>(file)), |
| 76 | + std::istreambuf_iterator<char>()); |
| 77 | + file.close(); |
| 78 | + return contents; |
| 79 | +} |
| 80 | + |
| 81 | +void SaveColorToWebp(int width, int height, const unsigned char* data, |
| 82 | + const std::string& filename) { |
| 83 | + uint8_t* webp = nullptr; |
| 84 | + const size_t size = |
| 85 | + WebPEncodeLosslessRGB(data, width, height, width * 3, &webp); |
| 86 | + |
| 87 | + std::ofstream file(filename, std::ios::binary); |
| 88 | + file.write(reinterpret_cast<const char*>(webp), size); |
| 89 | + file.close(); |
| 90 | + WebPFree(webp); |
| 91 | +} |
| 92 | + |
| 93 | +void SaveDepthToWebp(int width, int height, const float* data, |
| 94 | + const std::string& filename) { |
| 95 | + const int size = width * height; |
| 96 | + |
| 97 | + // Turn the depth buffer into a greyscale color buffer. |
| 98 | + std::vector<unsigned char> byte_buffer; |
| 99 | + byte_buffer.reserve(size * 3); |
| 100 | + for (int i = 0; i < size; ++i) { |
| 101 | + auto byte = static_cast<int>(255.0 * data[i]); |
| 102 | + byte_buffer.push_back(byte); |
| 103 | + byte_buffer.push_back(byte); |
| 104 | + byte_buffer.push_back(byte); |
| 105 | + } |
| 106 | + SaveColorToWebp(width, height, byte_buffer.data(), filename); |
| 107 | +} |
| 108 | + |
| 109 | +void SaveScreenshotToWebp(int width, int height, mjrContext* con, |
| 110 | + const std::string& filename) { |
| 111 | + mjr_setBuffer(mjFB_OFFSCREEN, con); |
| 112 | + auto rgb_buffer = std::vector<unsigned char>(3 * width * height); |
| 113 | + auto depth_buffer = std::vector<float>(width * height, 1.0f); |
| 114 | + mjrRect viewport = {0, 0, width, height}; |
| 115 | + mjr_readPixels(rgb_buffer.data(), depth_buffer.data(), viewport, con); |
| 116 | + mjr_setBuffer(mjFB_WINDOW, con); |
| 117 | + SaveColorToWebp(width, height, rgb_buffer.data(), filename); |
| 118 | +} |
| 119 | + |
| 120 | +const void* GetValue(const mjModel* model, const mjData* data, |
| 121 | + const char* field, int index) { |
| 122 | + MJDATA_POINTERS_PREAMBLE(model); |
| 123 | +#define X(TYPE, NAME, NR, NC) \ |
| 124 | + if (!std::strcmp(#NAME, field) && !std::strcmp(#TYPE, "mjtNum")) { \ |
| 125 | + if (index >= 0 && index < model->NR * NC) { \ |
| 126 | + return &data->NAME[index]; \ |
| 127 | + } else { \ |
| 128 | + return nullptr; \ |
| 129 | + } \ |
| 130 | + } |
| 131 | + MJDATA_POINTERS |
| 132 | +#undef X |
| 133 | + return nullptr; // Invalid field. |
| 134 | +} |
| 135 | + |
| 136 | +std::string CameraToString(const mjvScene* scene) { |
| 137 | + const mjvGLCamera* cameras = scene->camera; |
| 138 | + const float pos_x = (cameras[0].pos[0] + cameras[1].pos[0]) / 2; |
| 139 | + const float pos_y = (cameras[0].pos[1] + cameras[1].pos[1]) / 2; |
| 140 | + const float pos_z = (cameras[0].pos[2] + cameras[1].pos[2]) / 2; |
| 141 | + |
| 142 | + mjtNum cam_forward[3]; |
| 143 | + mju_f2n(cam_forward, cameras[0].forward, 3); |
| 144 | + mjtNum cam_up[3]; |
| 145 | + mju_f2n(cam_up, cameras[0].up, 3); |
| 146 | + mjtNum cam_right[3]; |
| 147 | + mju_cross(cam_right, cam_forward, cam_up); |
| 148 | + |
| 149 | + char str[500]; |
| 150 | + std::snprintf(str, sizeof(str), |
| 151 | + "<camera pos=\"%.3f %.3f %.3f\" xyaxes=\"%.3f %.3f %.3f %.3f " |
| 152 | + "%.3f %.3f\"/>\n", |
| 153 | + pos_x, pos_y, pos_z, cam_right[0], cam_right[1], cam_right[2], |
| 154 | + cam_up[0], cam_up[1], cam_up[2]); |
| 155 | + return str; |
| 156 | +} |
| 157 | + |
| 158 | +std::string KeyframeToString(const mjModel* model, const mjData* data, |
| 159 | + bool full_precision) { |
| 160 | + const int kStrLen = 5000; |
| 161 | + |
| 162 | + char buf[200]; |
| 163 | + const char p_regular[] = "%g"; |
| 164 | + const char p_full[] = "%-22.16g"; |
| 165 | + const char* format = full_precision ? p_full : p_regular; |
| 166 | + |
| 167 | + char str[kStrLen] = "<key\n"; |
| 168 | + |
| 169 | + // time |
| 170 | + std::strncat(str, " time=\"", kStrLen); |
| 171 | + std::snprintf(buf, sizeof(buf), format, data->time); |
| 172 | + std::strncat(str, buf, kStrLen); |
| 173 | + |
| 174 | + // qpos |
| 175 | + std::strncat(str, "\"\n qpos=\"", kStrLen); |
| 176 | + for (int i = 0; i < model->nq; i++) { |
| 177 | + std::snprintf(buf, sizeof(buf), format, data->qpos[i]); |
| 178 | + if (i < model->nq - 1) std::strncat(buf, " ", 200); |
| 179 | + std::strncat(str, buf, kStrLen); |
| 180 | + } |
| 181 | + |
| 182 | + // qvel |
| 183 | + std::strncat(str, "\"\n qvel=\"", kStrLen); |
| 184 | + for (int i = 0; i < model->nv; i++) { |
| 185 | + std::snprintf(buf, sizeof(buf), format, data->qvel[i]); |
| 186 | + if (i < model->nv - 1) std::strncat(buf, " ", 200); |
| 187 | + std::strncat(str, buf, kStrLen); |
| 188 | + } |
| 189 | + |
| 190 | + // act |
| 191 | + if (model->na > 0) { |
| 192 | + std::strncat(str, "\"\n act=\"", kStrLen); |
| 193 | + for (int i = 0; i < model->na; i++) { |
| 194 | + std::snprintf(buf, sizeof(buf), format, data->act[i]); |
| 195 | + if (i < model->na - 1) std::strncat(buf, " ", 200); |
| 196 | + std::strncat(str, buf, kStrLen); |
| 197 | + } |
| 198 | + } |
| 199 | + |
| 200 | + // ctrl |
| 201 | + if (model->nu > 0) { |
| 202 | + std::strncat(str, "\"\n ctrl=\"", kStrLen); |
| 203 | + for (int i = 0; i < model->nu; i++) { |
| 204 | + std::snprintf(buf, sizeof(buf), format, data->ctrl[i]); |
| 205 | + if (i < model->nu - 1) std::strncat(buf, " ", 200); |
| 206 | + std::strncat(str, buf, kStrLen); |
| 207 | + } |
| 208 | + } |
| 209 | + |
| 210 | + if (model->nmocap > 0) { |
| 211 | + std::strncat(str, "\"\n mpos=\"", kStrLen); |
| 212 | + for (int i = 0; i < 3 * model->nmocap; i++) { |
| 213 | + std::snprintf(buf, sizeof(buf), format, data->mocap_pos[i]); |
| 214 | + if (i < 3 * model->nmocap - 1) std::strncat(buf, " ", 200); |
| 215 | + std::strncat(str, buf, kStrLen); |
| 216 | + } |
| 217 | + |
| 218 | + // mocap_quat |
| 219 | + std::strncat(str, "\"\n mquat=\"", kStrLen); |
| 220 | + for (int i = 0; i < 4 * model->nmocap; i++) { |
| 221 | + std::snprintf(buf, sizeof(buf), format, data->mocap_quat[i]); |
| 222 | + if (i < 4 * model->nmocap - 1) std::strncat(buf, " ", 200); |
| 223 | + std::strncat(str, buf, kStrLen); |
| 224 | + } |
| 225 | + } |
| 226 | + |
| 227 | + std::strncat(str, "\"\n/>", kStrLen); |
| 228 | + return str; |
| 229 | +} |
| 230 | + |
| 231 | +} // namespace mujoco::toolbox |
0 commit comments