|
14 | 14 |
|
15 | 15 | #include "experimental/usd/kinematic_tree.h"
|
16 | 16 |
|
17 |
| -#include <algorithm> |
18 |
| -#include <deque> |
19 | 17 | #include <map>
|
20 | 18 | #include <memory>
|
21 |
| -#include <set> |
22 | 19 | #include <utility>
|
23 | 20 | #include <vector>
|
24 | 21 |
|
| 22 | +#include <mujoco/experimental/usd/mjcPhysics/actuator.h> |
| 23 | +#include <mujoco/experimental/usd/mjcPhysics/keyframe.h> |
| 24 | +#include <mujoco/experimental/usd/mjcPhysics/siteAPI.h> |
25 | 25 | #include <mujoco/mujoco.h>
|
26 | 26 | #include <pxr/usd/sdf/path.h>
|
| 27 | +#include <pxr/usd/usd/common.h> |
| 28 | +#include <pxr/usd/usd/primRange.h> |
| 29 | +#include <pxr/usd/usdGeom/xformCache.h> |
| 30 | +#include <pxr/usd/usdPhysics/collisionAPI.h> |
27 | 31 | #include <pxr/usd/usdPhysics/joint.h>
|
| 32 | +#include <pxr/usd/usdPhysics/rigidBodyAPI.h> |
| 33 | +#include <pxr/usd/usdPhysics/scene.h> |
28 | 34 |
|
29 | 35 | namespace mujoco {
|
30 | 36 | namespace usd {
|
31 | 37 |
|
32 |
| -bool GetJointBodies(const pxr::UsdPhysicsJoint& joint, |
33 |
| - const pxr::SdfPath& default_prim_path, pxr::SdfPath* from, |
| 38 | +bool GetJointBodies(const pxr::UsdPhysicsJoint& joint, pxr::SdfPath* from, |
34 | 39 | pxr::SdfPath* to) {
|
| 40 | + // Grab the default prim path |
| 41 | + pxr::SdfPath default_prim_path; |
| 42 | + auto stage = joint.GetPrim().GetStage(); |
| 43 | + if (stage->GetDefaultPrim().IsValid()) { |
| 44 | + default_prim_path = stage->GetDefaultPrim().GetPath(); |
| 45 | + } |
| 46 | + |
35 | 47 | pxr::SdfPathVector body1_paths;
|
36 | 48 | joint.GetBody1Rel().GetTargets(&body1_paths);
|
37 | 49 | if (body1_paths.empty()) {
|
@@ -62,121 +74,187 @@ bool GetJointBodies(const pxr::UsdPhysicsJoint& joint,
|
62 | 74 | return true;
|
63 | 75 | }
|
64 | 76 |
|
65 |
| -std::unique_ptr<KinematicNode> BuildKinematicTree( |
66 |
| - const std::vector<pxr::UsdPhysicsJoint>& joints, |
67 |
| - const std::vector<pxr::SdfPath>& all_body_paths, |
68 |
| - const pxr::SdfPath& default_prim_path) { |
69 |
| - std::map<pxr::SdfPath, std::vector<pxr::SdfPath>> children_map; |
70 |
| - std::map<pxr::SdfPath, pxr::SdfPath> parent_map; |
71 |
| - std::map<std::pair<pxr::SdfPath, pxr::SdfPath>, pxr::SdfPath> |
72 |
| - edge_to_joint_map; |
73 |
| - std::set<pxr::SdfPath> all_nodes(all_body_paths.begin(), |
74 |
| - all_body_paths.end()); |
75 |
| - |
76 |
| - for (const auto& joint : joints) { |
77 |
| - pxr::SdfPath from, to; |
78 |
| - if (!GetJointBodies(joint, default_prim_path, &from, &to)) { |
| 77 | +struct ExtractedPrims { |
| 78 | + std::vector<std::unique_ptr<Node>> nodes; |
| 79 | + std::vector<pxr::UsdPhysicsJoint> joints; |
| 80 | +}; |
| 81 | + |
| 82 | +ExtractedPrims ExtractPrims(pxr::UsdStageRefPtr stage) { |
| 83 | + pxr::UsdPhysicsScene physics_scene; |
| 84 | + std::vector<std::unique_ptr<Node>> nodes; |
| 85 | + std::vector<pxr::UsdPhysicsJoint> joints; |
| 86 | + nodes.push_back(std::make_unique<Node>()); |
| 87 | + Node* root = nodes.back().get(); |
| 88 | + |
| 89 | + // ========================================================================= |
| 90 | + // PASS 1: Collect Bodies, Joints, and Geoms/Sites/etc. |
| 91 | + // ========================================================================= |
| 92 | + // A single DFS pass to find all bodies, joints, and determine |
| 93 | + // which body owns each geom/site/etc. prim. |
| 94 | + std::vector<Node*> owner_stack; |
| 95 | + owner_stack.push_back(root); // Start with the world as owner. |
| 96 | + |
| 97 | + const auto range = pxr::UsdPrimRange::PreAndPostVisit( |
| 98 | + stage->GetPseudoRoot(), pxr::UsdTraverseInstanceProxies()); |
| 99 | + |
| 100 | + pxr::UsdGeomXformCache xform_cache; |
| 101 | + for (auto it = range.begin(); it != range.end(); ++it) { |
| 102 | + pxr::UsdPrim prim = *it; |
| 103 | + |
| 104 | + bool is_body = prim.HasAPI<pxr::UsdPhysicsRigidBodyAPI>(); |
| 105 | + bool resets = xform_cache.GetResetXformStack(prim); |
| 106 | + // Only update (push/pop) the owner stack for bodies (becomes new owner) and |
| 107 | + // resetXformStack (reset owner to world). |
| 108 | + bool is_pushed_to_stack = is_body || resets; |
| 109 | + |
| 110 | + if (it.IsPostVisit()) { |
| 111 | + if (is_pushed_to_stack) { |
| 112 | + owner_stack.pop_back(); |
| 113 | + } |
79 | 114 | continue;
|
80 | 115 | }
|
81 | 116 |
|
82 |
| - auto edge_key = std::make_pair(from, to); |
83 |
| - auto it = edge_to_joint_map.find(edge_key); |
84 |
| - if (it == edge_to_joint_map.end()) { |
85 |
| - edge_to_joint_map[edge_key] = joint.GetPath(); |
86 |
| - } else { |
87 |
| - mju_warning( |
88 |
| - "Multiple explicit joints defined between body %s and body %s. " |
89 |
| - "Joint1: %s, Joint2: %s. Keeping the first one found: %s", |
90 |
| - (from.IsEmpty() ? "<worldbody>" : from.GetString()).c_str(), |
91 |
| - to.GetString().c_str(), it->second.GetString().c_str(), |
92 |
| - joint.GetPath().GetString().c_str(), it->second.GetString().c_str()); |
93 |
| - continue; |
| 117 | + pxr::SdfPath prim_path = prim.GetPath(); |
| 118 | + Node* current_node = owner_stack.back(); |
| 119 | + |
| 120 | + if (is_body) { |
| 121 | + auto new_node = std::make_unique<Node>(); |
| 122 | + new_node->body_path = prim_path; |
| 123 | + nodes.push_back(std::move(new_node)); |
| 124 | + current_node = nodes.back().get(); |
| 125 | + } else if (resets) { |
| 126 | + current_node = root; // Reset owner to world. |
94 | 127 | }
|
95 | 128 |
|
96 |
| - if (from == to) { |
97 |
| - mju_error("Self-loop detected at node %s", to.GetString().c_str()); |
98 |
| - return nullptr; |
| 129 | + if (is_pushed_to_stack) { |
| 130 | + owner_stack.push_back(current_node); |
99 | 131 | }
|
100 |
| - if (parent_map.count(to)) { |
101 |
| - mju_error("Node %s has multiple parents ('%s' and '%s').", |
102 |
| - to.GetString().c_str(), parent_map.at(to).GetString().c_str(), |
103 |
| - from.GetString().c_str()); |
104 |
| - return nullptr; |
| 132 | + |
| 133 | + if (prim.IsA<pxr::UsdPhysicsScene>() && root->physics_scene.IsEmpty()) { |
| 134 | + root->physics_scene = prim_path; |
| 135 | + } |
| 136 | + |
| 137 | + if (prim.HasAPI<pxr::UsdPhysicsCollisionAPI>()) { |
| 138 | + current_node->colliders.push_back(prim.GetPath()); |
| 139 | + } |
| 140 | + |
| 141 | + if (prim.HasAPI<pxr::MjcPhysicsSiteAPI>()) { |
| 142 | + current_node->sites.push_back(prim.GetPath()); |
| 143 | + // Sites should not have children. |
| 144 | + it.PruneChildren(); |
| 145 | + } |
| 146 | + |
| 147 | + if (prim.IsA<pxr::UsdPhysicsJoint>()) { |
| 148 | + // We may not know which body this belongs to yet so we'll add it to a |
| 149 | + // list and the caller can assign the joints when building the tree. |
| 150 | + joints.push_back(pxr::UsdPhysicsJoint(prim)); |
| 151 | + // Joints should not have children. |
| 152 | + it.PruneChildren(); |
| 153 | + } |
| 154 | + |
| 155 | + if (prim.IsA<pxr::MjcPhysicsActuator>()) { |
| 156 | + root->actuators.push_back(prim_path); |
| 157 | + // Joints should not have children. |
| 158 | + it.PruneChildren(); |
| 159 | + } |
| 160 | + |
| 161 | + if (prim.IsA<pxr::MjcPhysicsKeyframe>()) { |
| 162 | + root->keyframes.push_back(prim.GetPath()); |
| 163 | + // Keyframes should not have children. |
| 164 | + it.PruneChildren(); |
105 | 165 | }
|
106 |
| - children_map[from].push_back(to); |
107 |
| - parent_map[to] = from; |
108 |
| - all_nodes.insert(from); |
109 |
| - all_nodes.insert(to); |
110 | 166 | }
|
| 167 | + return {.nodes = std::move(nodes), .joints = std::move(joints)}; |
| 168 | +} |
| 169 | + |
| 170 | +std::unique_ptr<Node> BuildKinematicTree(const pxr::UsdStageRefPtr stage) { |
| 171 | + ExtractedPrims extraction = ExtractPrims(stage); |
111 | 172 |
|
112 |
| - // Sort children in children_map to respect the DFS order from the stage. |
113 |
| - for (auto& [_, children] : children_map) { |
114 |
| - std::sort( |
115 |
| - children.begin(), children.end(), |
116 |
| - [&v = all_body_paths](const auto& a, const auto& b) { |
117 |
| - return std::distance(v.begin(), std::find(v.begin(), v.end(), a)) < |
118 |
| - std::distance(v.begin(), std::find(v.begin(), v.end(), b)); |
119 |
| - }); |
| 173 | + std::map<pxr::SdfPath, int> body_index; |
| 174 | + body_index[pxr::SdfPath()] = 0; |
| 175 | + for (int i = 0; i < extraction.nodes.size(); ++i) { |
| 176 | + body_index[extraction.nodes[i]->body_path] = i; |
120 | 177 | }
|
121 | 178 |
|
122 |
| - // The world body is represented by an empty SdfPath. |
123 |
| - auto world_root = std::make_unique<KinematicNode>(); |
124 |
| - std::map<pxr::SdfPath, KinematicNode*> node_map; |
125 |
| - node_map[pxr::SdfPath()] = world_root.get(); |
126 |
| - |
127 |
| - // Use a deque for traversal. We will add roots to the back and children |
128 |
| - // to the front to perform a DFS on each root's tree. |
129 |
| - std::deque<pxr::SdfPath> q; |
130 |
| - |
131 |
| - // Add roots (floating-base bodies and children of the world) to the queue, |
132 |
| - // preserving the DFS order from the USD stage. |
133 |
| - for (const auto& body_path : all_body_paths) { |
134 |
| - if (!body_path.IsEmpty()) { |
135 |
| - const auto it = parent_map.find(body_path); |
136 |
| - // A root is a body that has no parent, or its parent is the world. |
137 |
| - if (it == parent_map.end() || it->second.IsEmpty()) { |
138 |
| - q.push_back(body_path); |
139 |
| - } |
| 179 | + // List of direct children for each body. |
| 180 | + std::vector<std::vector<bool>> children( |
| 181 | + extraction.nodes.size(), std::vector<bool>(extraction.nodes.size())); |
| 182 | + // List of joint prim paths associated with a child for each body. |
| 183 | + std::vector<std::vector<pxr::SdfPath>> parent_joints(extraction.nodes.size()); |
| 184 | + for (const pxr::UsdPhysicsJoint& joint : extraction.joints) { |
| 185 | + pxr::SdfPath from, to; |
| 186 | + if (!GetJointBodies(joint, &from, &to)) { |
| 187 | + continue; |
| 188 | + } |
| 189 | + |
| 190 | + int from_idx = body_index[from]; |
| 191 | + int to_idx = body_index[to]; |
| 192 | + if (from_idx == to_idx) { |
| 193 | + mju_error("Cycle detected: self referencing joint at node %s", |
| 194 | + to.GetString().c_str()); |
| 195 | + return nullptr; |
140 | 196 | }
|
| 197 | + |
| 198 | + children[from_idx][to_idx] = true; |
| 199 | + parent_joints[to_idx].push_back(joint.GetPath()); |
| 200 | + // Now that we know all the bodies, we can assign joints to respective |
| 201 | + // nodes. |
| 202 | + extraction.nodes[to_idx]->joints.push_back(joint.GetPath()); |
141 | 203 | }
|
142 | 204 |
|
143 |
| - while (!q.empty()) { |
144 |
| - pxr::SdfPath current_path = q.front(); |
145 |
| - q.pop_front(); |
| 205 | + // The world body is represented by an empty SdfPath. |
| 206 | + auto world_root = std::move(extraction.nodes[0]); |
146 | 207 |
|
147 |
| - pxr::SdfPath parent_path = parent_map.count(current_path) |
148 |
| - ? parent_map.at(current_path) |
149 |
| - : pxr::SdfPath(); |
150 |
| - KinematicNode* parent_node = node_map.at(parent_path); |
| 208 | + std::vector<std::pair<int, Node*>> stack; |
| 209 | + stack.emplace_back(0, world_root.get()); |
151 | 210 |
|
152 |
| - auto new_node = std::make_unique<KinematicNode>(); |
153 |
| - new_node->body_path = current_path; |
154 |
| - if (edge_to_joint_map.count({parent_path, current_path})) { |
155 |
| - new_node->joint_path = edge_to_joint_map.at({parent_path, current_path}); |
| 211 | + // A node without any joints from a parent has a free joint. |
| 212 | + // Add all free joints as children of the world body. |
| 213 | + for (int i = 1; i < extraction.nodes.size(); ++i) { |
| 214 | + if (parent_joints[i].empty()) { |
| 215 | + children[0][i] = true; |
156 | 216 | }
|
157 |
| - node_map[current_path] = new_node.get(); |
158 |
| - parent_node->children.push_back(std::move(new_node)); |
159 |
| - |
160 |
| - if (children_map.count(current_path)) { |
161 |
| - const auto& children = children_map.at(current_path); |
162 |
| - // Add children to the front of the queue in reverse order to ensure |
163 |
| - // they are processed in the correct order by the DFS. |
164 |
| - for (auto it = children.rbegin(); it != children.rend(); ++it) { |
165 |
| - q.push_front(*it); |
| 217 | + } |
| 218 | + |
| 219 | + std::vector<bool> visited(extraction.nodes.size()); |
| 220 | + while (!stack.empty()) { |
| 221 | + auto [current_body, parent] = stack.back(); |
| 222 | + stack.pop_back(); |
| 223 | + visited[current_body] = true; |
| 224 | + |
| 225 | + Node* current_node = nullptr; |
| 226 | + if (current_body > 0) { |
| 227 | + parent->children.push_back(std::move(extraction.nodes[current_body])); |
| 228 | + current_node = parent->children.back().get(); |
| 229 | + } else { |
| 230 | + current_node = world_root.get(); |
| 231 | + } |
| 232 | + |
| 233 | + // Process children in reverse to maintain DFS. |
| 234 | + for (int i = extraction.nodes.size() - 1; i > 0; --i) { |
| 235 | + if (!children[current_body][i]) { |
| 236 | + continue; |
166 | 237 | }
|
| 238 | + if (visited[i]) { |
| 239 | + mju_error("Cycle detected in the kinematic tree at node %s", |
| 240 | + current_node->body_path.GetString().c_str()); |
| 241 | + return nullptr; |
| 242 | + } |
| 243 | + stack.emplace_back( |
| 244 | + i, current_body > 0 ? parent->children.back().get() : parent); |
167 | 245 | }
|
168 | 246 | }
|
169 | 247 |
|
170 |
| - // After traversal, check for unvisited nodes. |
171 |
| - // Unvisited nodes at this point imply a cycle. |
172 |
| - for (const auto& node : all_nodes) { |
173 |
| - if (!node.IsEmpty() && !node_map.count(node)) { |
174 |
| - mju_error("Cycle detected involving node %s.", node.GetString().c_str()); |
| 248 | + for (int i = 1; i < visited.size(); ++i) { |
| 249 | + if (!visited[i]) { |
| 250 | + mju_error("Cycle detected: Node %s is not reachable from the world.", |
| 251 | + extraction.nodes[i]->body_path.GetString().c_str()); |
175 | 252 | return nullptr;
|
176 | 253 | }
|
177 | 254 | }
|
178 | 255 |
|
179 | 256 | return world_root;
|
180 | 257 | }
|
| 258 | + |
181 | 259 | } // namespace usd
|
182 | 260 | } // namespace mujoco
|
0 commit comments