Skip to content

Commit 56aa3f4

Browse files
author
Peter
committed
drop dm_control in favor of just mujoco
- dm_control is not maintained - add pass-through support so more models can be loaded
1 parent e7e1f3a commit 56aa3f4

File tree

3 files changed

+62
-43
lines changed

3 files changed

+62
-43
lines changed

CITATION.cff

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ authors:
99
orcid: https://orcid.org/0000-0002-2439-3262
1010
- family-names: Gupta
1111
given-names: Ashwin
12+
- family-names: Mitrano
13+
given-names: Peter
14+
orcid: https://orcid.org/0000-0002-8701-9809
1215
title: PyTorch Kinematics
1316
doi: 10.5281/zenodo.7700588
1417
version: v0.5.4

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ dependencies = [# Optional
6969
# Similar to `dependencies` above, these must be valid existing
7070
# projects.
7171
[project.optional-dependencies] # Optional
72-
test = ["pytest", "dm_control"]
73-
mujoco = ["dm_control"]
72+
test = ["pytest", "mujoco"]
73+
mujoco = ["mujoco"]
7474

7575
# List URLs that are relevant to your project
7676
#

src/pytorch_kinematics/mjcf.py

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,87 @@
1+
from typing import Union
2+
3+
import mujoco
4+
from mujoco._structs import _MjModelBodyViews as MjModelBodyViews
5+
16
import pytorch_kinematics.transforms as tf
27
from . import chain
38
from . import frame
49

5-
JOINT_TYPE_MAP = {'hinge': 'revolute', "slide": "prismatic"}
10+
# Converts from MuJoCo joint types to pytorch_kinematics joint types
11+
JOINT_TYPE_MAP = {
12+
mujoco.mjtJoint.mjJNT_HINGE: 'revolute',
13+
mujoco.mjtJoint.mjJNT_SLIDE: "prismatic"
14+
}
615

716

8-
def geoms_to_visuals(geom):
17+
def body_to_geoms(m: mujoco.MjModel, body: MjModelBodyViews):
18+
# Find all geoms which have body as parent
919
visuals = []
10-
for g in geom:
11-
if g.type == 'capsule':
12-
param = (g.size[0], g.fromto)
13-
elif g.type == 'sphere':
14-
param = g.size[0]
15-
elif g.type == 'mesh':
16-
param = None
17-
else:
18-
raise ValueError('Invalid geometry type %s.' % g.type)
19-
visuals.append(frame.Visual(offset=tf.Transform3d(rot=g.quat, pos=g.pos), geom_type=g.type, geom_param=param))
20+
for geom_id in range(m.ngeom):
21+
geom = m.geom(geom_id)
22+
if geom.bodyid == body.id:
23+
visuals.append(frame.Visual(offset=tf.Transform3d(rot=geom.quat, pos=geom.pos), geom_type=geom.type,
24+
geom_param=geom.size))
2025
return visuals
2126

2227

23-
def _build_chain_recurse(parent_frame, parent_body):
24-
parent_frame.link.visuals = geoms_to_visuals(parent_body.geom)
25-
for b in parent_body.body:
26-
n_joints = len(b.joint)
27-
if n_joints > 1:
28-
raise ValueError("composite joints not supported (could implement this if needed)")
29-
if n_joints == 1:
30-
joint = b.joint[0]
31-
child_joint = frame.Joint(joint.name, tf.Transform3d(pos=joint.pos), axis=joint.axis,
32-
joint_type=JOINT_TYPE_MAP[joint.type])
33-
else:
34-
child_joint = frame.Joint(b.name + "_imaginary_fixed_joint")
35-
child_link = frame.Link(b.name, offset=tf.Transform3d(rot=b.quat, pos=b.pos))
36-
child_frame = frame.Frame(name=b.name, link=child_link, joint=child_joint)
37-
parent_frame.children = parent_frame.children + (child_frame,)
38-
_build_chain_recurse(child_frame, b)
39-
40-
for site in parent_body.site:
41-
site_link = frame.Link(site.name, offset=tf.Transform3d(rot=site.quat, pos=site.pos))
42-
site_frame = frame.Frame(name=site.name, link=site_link)
43-
parent_frame.children = parent_frame.children + (site_frame,)
44-
45-
46-
def build_chain_from_mjcf(data):
28+
def _build_chain_recurse(m, parent_frame, parent_body):
29+
parent_frame.link.visuals = body_to_geoms(m, parent_body)
30+
# iterate through all bodies that are children of parent_body
31+
for body_id in range(m.nbody):
32+
body = m.body(body_id)
33+
if body.parentid == parent_body.id and body_id != parent_body.id:
34+
n_joints = body.jntnum
35+
if n_joints > 1:
36+
raise ValueError("composite joints not supported (could implement this if needed)")
37+
if n_joints == 1:
38+
# Find the joint for this body
39+
for jntid in body.jntadr:
40+
joint = m.joint(jntid)
41+
child_joint = frame.Joint(joint.name, tf.Transform3d(pos=joint.pos), axis=joint.axis,
42+
joint_type=JOINT_TYPE_MAP[joint.type[0]])
43+
else:
44+
child_joint = frame.Joint(body.name + "_fixed_joint")
45+
child_link = frame.Link(body.name, offset=tf.Transform3d(rot=body.quat, pos=body.pos))
46+
child_frame = frame.Frame(name=body.name, link=child_link, joint=child_joint)
47+
parent_frame.children = parent_frame.children + (child_frame,)
48+
_build_chain_recurse(m, child_frame, body)
49+
50+
# iterate through all sites that are children of parent_body
51+
for site_id in range(m.nsite):
52+
site = m.site(site_id)
53+
if site.bodyid == parent_body.id:
54+
site_link = frame.Link(site.name, offset=tf.Transform3d(rot=site.quat, pos=site.pos))
55+
site_frame = frame.Frame(name=site.name, link=site_link)
56+
parent_frame.children = parent_frame.children + (site_frame,)
57+
58+
59+
def build_chain_from_mjcf(data, body: Union[None, str, int] = None):
4760
"""
4861
Build a Chain object from MJCF data.
4962
5063
Parameters
5164
----------
5265
data : str
5366
MJCF string data.
67+
body : str or int, optional
68+
The name or index of the body to use as the root of the chain. If None, body idx=0 is used.
5469
5570
Returns
5671
-------
5772
chain.Chain
5873
Chain object created from MJCF.
5974
"""
60-
from dm_control import mjcf
61-
62-
model = mjcf.from_xml_string(data)
63-
root_body = model.worldbody.body[0]
75+
m = mujoco.MjModel.from_xml_string(data)
76+
if body is None:
77+
root_body = m.body(0)
78+
else:
79+
root_body = m.body(body)
6480
root_frame = frame.Frame(root_body.name + "_frame",
6581
link=frame.Link(root_body.name,
6682
offset=tf.Transform3d(rot=root_body.quat, pos=root_body.pos)),
6783
joint=frame.Joint())
68-
_build_chain_recurse(root_frame, root_body)
84+
_build_chain_recurse(m, root_frame, root_body)
6985
return chain.Chain(root_frame)
7086

7187

0 commit comments

Comments
 (0)