Skip to content

Commit 2e13c64

Browse files
havesscopybara-github
authored andcommitted
Simplify parsing from USD stage to kinematic tree for mjSpec population.
Summary of changes: - Move stage parsing from usd_to_mjspec to kinematic_tree - A Node now has lists of paths for prims that we support parsing. - When parsing the stage, we create a list of all the Nodes that we will place in our tree. - Removed the need for expensive SdfPath maps and sorting of bodies in favor of indexed integer arrays. - Simplifies usd_to_mjspec parsing as we can process each node and it's owned prims one at a time. PiperOrigin-RevId: 784187799 Change-Id: I4907cfa1137014af21c45b10fa29ac53a49b6dac
1 parent e232265 commit 2e13c64

File tree

3 files changed

+247
-246
lines changed

3 files changed

+247
-246
lines changed

src/experimental/usd/kinematic_tree.cc

Lines changed: 172 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,36 @@
1414

1515
#include "experimental/usd/kinematic_tree.h"
1616

17-
#include <algorithm>
18-
#include <deque>
1917
#include <map>
2018
#include <memory>
21-
#include <set>
2219
#include <utility>
2320
#include <vector>
2421

22+
#include <mujoco/experimental/usd/mjcPhysics/actuator.h>
23+
#include <mujoco/experimental/usd/mjcPhysics/keyframe.h>
24+
#include <mujoco/experimental/usd/mjcPhysics/siteAPI.h>
2525
#include <mujoco/mujoco.h>
2626
#include <pxr/usd/sdf/path.h>
27+
#include <pxr/usd/usd/common.h>
28+
#include <pxr/usd/usd/primRange.h>
29+
#include <pxr/usd/usdGeom/xformCache.h>
30+
#include <pxr/usd/usdPhysics/collisionAPI.h>
2731
#include <pxr/usd/usdPhysics/joint.h>
32+
#include <pxr/usd/usdPhysics/rigidBodyAPI.h>
33+
#include <pxr/usd/usdPhysics/scene.h>
2834

2935
namespace mujoco {
3036
namespace usd {
3137

32-
bool GetJointBodies(const pxr::UsdPhysicsJoint& joint,
33-
const pxr::SdfPath& default_prim_path, pxr::SdfPath* from,
38+
bool GetJointBodies(const pxr::UsdPhysicsJoint& joint, pxr::SdfPath* from,
3439
pxr::SdfPath* to) {
40+
// Grab the default prim path
41+
pxr::SdfPath default_prim_path;
42+
auto stage = joint.GetPrim().GetStage();
43+
if (stage->GetDefaultPrim().IsValid()) {
44+
default_prim_path = stage->GetDefaultPrim().GetPath();
45+
}
46+
3547
pxr::SdfPathVector body1_paths;
3648
joint.GetBody1Rel().GetTargets(&body1_paths);
3749
if (body1_paths.empty()) {
@@ -62,121 +74,187 @@ bool GetJointBodies(const pxr::UsdPhysicsJoint& joint,
6274
return true;
6375
}
6476

65-
std::unique_ptr<KinematicNode> BuildKinematicTree(
66-
const std::vector<pxr::UsdPhysicsJoint>& joints,
67-
const std::vector<pxr::SdfPath>& all_body_paths,
68-
const pxr::SdfPath& default_prim_path) {
69-
std::map<pxr::SdfPath, std::vector<pxr::SdfPath>> children_map;
70-
std::map<pxr::SdfPath, pxr::SdfPath> parent_map;
71-
std::map<std::pair<pxr::SdfPath, pxr::SdfPath>, pxr::SdfPath>
72-
edge_to_joint_map;
73-
std::set<pxr::SdfPath> all_nodes(all_body_paths.begin(),
74-
all_body_paths.end());
75-
76-
for (const auto& joint : joints) {
77-
pxr::SdfPath from, to;
78-
if (!GetJointBodies(joint, default_prim_path, &from, &to)) {
77+
struct ExtractedPrims {
78+
std::vector<std::unique_ptr<Node>> nodes;
79+
std::vector<pxr::UsdPhysicsJoint> joints;
80+
};
81+
82+
ExtractedPrims ExtractPrims(pxr::UsdStageRefPtr stage) {
83+
pxr::UsdPhysicsScene physics_scene;
84+
std::vector<std::unique_ptr<Node>> nodes;
85+
std::vector<pxr::UsdPhysicsJoint> joints;
86+
nodes.push_back(std::make_unique<Node>());
87+
Node* root = nodes.back().get();
88+
89+
// =========================================================================
90+
// PASS 1: Collect Bodies, Joints, and Geoms/Sites/etc.
91+
// =========================================================================
92+
// A single DFS pass to find all bodies, joints, and determine
93+
// which body owns each geom/site/etc. prim.
94+
std::vector<Node*> owner_stack;
95+
owner_stack.push_back(root); // Start with the world as owner.
96+
97+
const auto range = pxr::UsdPrimRange::PreAndPostVisit(
98+
stage->GetPseudoRoot(), pxr::UsdTraverseInstanceProxies());
99+
100+
pxr::UsdGeomXformCache xform_cache;
101+
for (auto it = range.begin(); it != range.end(); ++it) {
102+
pxr::UsdPrim prim = *it;
103+
104+
bool is_body = prim.HasAPI<pxr::UsdPhysicsRigidBodyAPI>();
105+
bool resets = xform_cache.GetResetXformStack(prim);
106+
// Only update (push/pop) the owner stack for bodies (becomes new owner) and
107+
// resetXformStack (reset owner to world).
108+
bool is_pushed_to_stack = is_body || resets;
109+
110+
if (it.IsPostVisit()) {
111+
if (is_pushed_to_stack) {
112+
owner_stack.pop_back();
113+
}
79114
continue;
80115
}
81116

82-
auto edge_key = std::make_pair(from, to);
83-
auto it = edge_to_joint_map.find(edge_key);
84-
if (it == edge_to_joint_map.end()) {
85-
edge_to_joint_map[edge_key] = joint.GetPath();
86-
} else {
87-
mju_warning(
88-
"Multiple explicit joints defined between body %s and body %s. "
89-
"Joint1: %s, Joint2: %s. Keeping the first one found: %s",
90-
(from.IsEmpty() ? "<worldbody>" : from.GetString()).c_str(),
91-
to.GetString().c_str(), it->second.GetString().c_str(),
92-
joint.GetPath().GetString().c_str(), it->second.GetString().c_str());
93-
continue;
117+
pxr::SdfPath prim_path = prim.GetPath();
118+
Node* current_node = owner_stack.back();
119+
120+
if (is_body) {
121+
auto new_node = std::make_unique<Node>();
122+
new_node->body_path = prim_path;
123+
nodes.push_back(std::move(new_node));
124+
current_node = nodes.back().get();
125+
} else if (resets) {
126+
current_node = root; // Reset owner to world.
94127
}
95128

96-
if (from == to) {
97-
mju_error("Self-loop detected at node %s", to.GetString().c_str());
98-
return nullptr;
129+
if (is_pushed_to_stack) {
130+
owner_stack.push_back(current_node);
99131
}
100-
if (parent_map.count(to)) {
101-
mju_error("Node %s has multiple parents ('%s' and '%s').",
102-
to.GetString().c_str(), parent_map.at(to).GetString().c_str(),
103-
from.GetString().c_str());
104-
return nullptr;
132+
133+
if (prim.IsA<pxr::UsdPhysicsScene>() && root->physics_scene.IsEmpty()) {
134+
root->physics_scene = prim_path;
135+
}
136+
137+
if (prim.HasAPI<pxr::UsdPhysicsCollisionAPI>()) {
138+
current_node->colliders.push_back(prim.GetPath());
139+
}
140+
141+
if (prim.HasAPI<pxr::MjcPhysicsSiteAPI>()) {
142+
current_node->sites.push_back(prim.GetPath());
143+
// Sites should not have children.
144+
it.PruneChildren();
145+
}
146+
147+
if (prim.IsA<pxr::UsdPhysicsJoint>()) {
148+
// We may not know which body this belongs to yet so we'll add it to a
149+
// list and the caller can assign the joints when building the tree.
150+
joints.push_back(pxr::UsdPhysicsJoint(prim));
151+
// Joints should not have children.
152+
it.PruneChildren();
153+
}
154+
155+
if (prim.IsA<pxr::MjcPhysicsActuator>()) {
156+
root->actuators.push_back(prim_path);
157+
// Joints should not have children.
158+
it.PruneChildren();
159+
}
160+
161+
if (prim.IsA<pxr::MjcPhysicsKeyframe>()) {
162+
root->keyframes.push_back(prim.GetPath());
163+
// Keyframes should not have children.
164+
it.PruneChildren();
105165
}
106-
children_map[from].push_back(to);
107-
parent_map[to] = from;
108-
all_nodes.insert(from);
109-
all_nodes.insert(to);
110166
}
167+
return {.nodes = std::move(nodes), .joints = std::move(joints)};
168+
}
169+
170+
std::unique_ptr<Node> BuildKinematicTree(const pxr::UsdStageRefPtr stage) {
171+
ExtractedPrims extraction = ExtractPrims(stage);
111172

112-
// Sort children in children_map to respect the DFS order from the stage.
113-
for (auto& [_, children] : children_map) {
114-
std::sort(
115-
children.begin(), children.end(),
116-
[&v = all_body_paths](const auto& a, const auto& b) {
117-
return std::distance(v.begin(), std::find(v.begin(), v.end(), a)) <
118-
std::distance(v.begin(), std::find(v.begin(), v.end(), b));
119-
});
173+
std::map<pxr::SdfPath, int> body_index;
174+
body_index[pxr::SdfPath()] = 0;
175+
for (int i = 0; i < extraction.nodes.size(); ++i) {
176+
body_index[extraction.nodes[i]->body_path] = i;
120177
}
121178

122-
// The world body is represented by an empty SdfPath.
123-
auto world_root = std::make_unique<KinematicNode>();
124-
std::map<pxr::SdfPath, KinematicNode*> node_map;
125-
node_map[pxr::SdfPath()] = world_root.get();
126-
127-
// Use a deque for traversal. We will add roots to the back and children
128-
// to the front to perform a DFS on each root's tree.
129-
std::deque<pxr::SdfPath> q;
130-
131-
// Add roots (floating-base bodies and children of the world) to the queue,
132-
// preserving the DFS order from the USD stage.
133-
for (const auto& body_path : all_body_paths) {
134-
if (!body_path.IsEmpty()) {
135-
const auto it = parent_map.find(body_path);
136-
// A root is a body that has no parent, or its parent is the world.
137-
if (it == parent_map.end() || it->second.IsEmpty()) {
138-
q.push_back(body_path);
139-
}
179+
// List of direct children for each body.
180+
std::vector<std::vector<bool>> children(
181+
extraction.nodes.size(), std::vector<bool>(extraction.nodes.size()));
182+
// List of joint prim paths associated with a child for each body.
183+
std::vector<std::vector<pxr::SdfPath>> parent_joints(extraction.nodes.size());
184+
for (const pxr::UsdPhysicsJoint& joint : extraction.joints) {
185+
pxr::SdfPath from, to;
186+
if (!GetJointBodies(joint, &from, &to)) {
187+
continue;
188+
}
189+
190+
int from_idx = body_index[from];
191+
int to_idx = body_index[to];
192+
if (from_idx == to_idx) {
193+
mju_error("Cycle detected: self referencing joint at node %s",
194+
to.GetString().c_str());
195+
return nullptr;
140196
}
197+
198+
children[from_idx][to_idx] = true;
199+
parent_joints[to_idx].push_back(joint.GetPath());
200+
// Now that we know all the bodies, we can assign joints to respective
201+
// nodes.
202+
extraction.nodes[to_idx]->joints.push_back(joint.GetPath());
141203
}
142204

143-
while (!q.empty()) {
144-
pxr::SdfPath current_path = q.front();
145-
q.pop_front();
205+
// The world body is represented by an empty SdfPath.
206+
auto world_root = std::move(extraction.nodes[0]);
146207

147-
pxr::SdfPath parent_path = parent_map.count(current_path)
148-
? parent_map.at(current_path)
149-
: pxr::SdfPath();
150-
KinematicNode* parent_node = node_map.at(parent_path);
208+
std::vector<std::pair<int, Node*>> stack;
209+
stack.emplace_back(0, world_root.get());
151210

152-
auto new_node = std::make_unique<KinematicNode>();
153-
new_node->body_path = current_path;
154-
if (edge_to_joint_map.count({parent_path, current_path})) {
155-
new_node->joint_path = edge_to_joint_map.at({parent_path, current_path});
211+
// A node without any joints from a parent has a free joint.
212+
// Add all free joints as children of the world body.
213+
for (int i = 1; i < extraction.nodes.size(); ++i) {
214+
if (parent_joints[i].empty()) {
215+
children[0][i] = true;
156216
}
157-
node_map[current_path] = new_node.get();
158-
parent_node->children.push_back(std::move(new_node));
159-
160-
if (children_map.count(current_path)) {
161-
const auto& children = children_map.at(current_path);
162-
// Add children to the front of the queue in reverse order to ensure
163-
// they are processed in the correct order by the DFS.
164-
for (auto it = children.rbegin(); it != children.rend(); ++it) {
165-
q.push_front(*it);
217+
}
218+
219+
std::vector<bool> visited(extraction.nodes.size());
220+
while (!stack.empty()) {
221+
auto [current_body, parent] = stack.back();
222+
stack.pop_back();
223+
visited[current_body] = true;
224+
225+
Node* current_node = nullptr;
226+
if (current_body > 0) {
227+
parent->children.push_back(std::move(extraction.nodes[current_body]));
228+
current_node = parent->children.back().get();
229+
} else {
230+
current_node = world_root.get();
231+
}
232+
233+
// Process children in reverse to maintain DFS.
234+
for (int i = extraction.nodes.size() - 1; i > 0; --i) {
235+
if (!children[current_body][i]) {
236+
continue;
166237
}
238+
if (visited[i]) {
239+
mju_error("Cycle detected in the kinematic tree at node %s",
240+
current_node->body_path.GetString().c_str());
241+
return nullptr;
242+
}
243+
stack.emplace_back(
244+
i, current_body > 0 ? parent->children.back().get() : parent);
167245
}
168246
}
169247

170-
// After traversal, check for unvisited nodes.
171-
// Unvisited nodes at this point imply a cycle.
172-
for (const auto& node : all_nodes) {
173-
if (!node.IsEmpty() && !node_map.count(node)) {
174-
mju_error("Cycle detected involving node %s.", node.GetString().c_str());
248+
for (int i = 1; i < visited.size(); ++i) {
249+
if (!visited[i]) {
250+
mju_error("Cycle detected: Node %s is not reachable from the world.",
251+
extraction.nodes[i]->body_path.GetString().c_str());
175252
return nullptr;
176253
}
177254
}
178255

179256
return world_root;
180257
}
258+
181259
} // namespace usd
182260
} // namespace mujoco

src/experimental/usd/kinematic_tree.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,27 @@ namespace usd {
2727
// A struct to represent a node in the kinematic tree.
2828
// Using a struct with a vector of children preserves the order of bodies,
2929
// which is important for things like keyframes and policy compatibility.
30-
struct KinematicNode {
30+
struct Node {
3131
pxr::SdfPath body_path;
32-
pxr::SdfPath joint_path; // Joint connecting this node to its parent.
33-
std::vector<std::unique_ptr<KinematicNode>> children;
32+
pxr::SdfPath physics_scene;
33+
std::vector<pxr::SdfPath> actuators;
34+
std::vector<pxr::SdfPath> joints;
35+
std::vector<pxr::SdfPath> colliders;
36+
std::vector<pxr::SdfPath> sites;
37+
std::vector<pxr::SdfPath> keyframes;
38+
std::vector<std::unique_ptr<Node>> children;
3439
};
3540

36-
// Builds a single kinematic tree from a list of joints.
41+
// A kinematic edge represents a joint.
42+
using JointVec = std::vector<pxr::UsdPhysicsJoint>;
43+
44+
// Builds a single kinematic tree from a list of directed edges.
3745
// The DFS order of bodies in the tree is determined by the order of bodies in
3846
// `all_body_paths`.
3947
// All bodies, including static and floating-base bodies, are organized under a
4048
// single world root. An empty 'from' path in an edge represents the world body.
4149
// Returns the root of the kinematic tree, or `nullptr` for invalid structures.
42-
std::unique_ptr<KinematicNode> BuildKinematicTree(
43-
const std::vector<pxr::UsdPhysicsJoint>& joints,
44-
const std::vector<pxr::SdfPath>& all_body_paths,
45-
const pxr::SdfPath& default_prim_path);
50+
std::unique_ptr<Node> BuildKinematicTree(const pxr::UsdStageRefPtr stage);
4651

4752
} // namespace usd
4853
} // namespace mujoco

0 commit comments

Comments
 (0)