Skip to content

Commit e97f291

Browse files
authored
Merge pull request #22 from UM-ARM-Lab/mujoco
drop dm_control in favor of just mujoco
2 parents e7e1f3a + 69f610f commit e97f291

File tree

9 files changed

+183
-103
lines changed

9 files changed

+183
-103
lines changed

.github/workflows/python-package.yml

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,30 @@ jobs:
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
python-version: ["3.8", "3.9", "3.10"]
19+
python-version: [ "3.8", "3.9", "3.10" ]
2020

2121
steps:
22-
- uses: actions/checkout@v3
23-
- name: Set up Python ${{ matrix.python-version }}
24-
uses: actions/setup-python@v3
25-
with:
26-
python-version: ${{ matrix.python-version }}
27-
- name: Install dependencies
28-
run: |
29-
python -m pip install --upgrade pip
30-
python -m pip install .[test]
31-
python -m pip install flake8 pytest
32-
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
33-
- name: Lint with flake8
34-
run: |
35-
# stop the build if there are Python syntax errors or undefined names
36-
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
37-
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
38-
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
39-
- name: Test with pytest
40-
run: |
41-
pytest
22+
- uses: actions/checkout@v3
23+
- name: Set up Python ${{ matrix.python-version }}
24+
uses: actions/setup-python@v3
25+
with:
26+
python-version: ${{ matrix.python-version }}
27+
- name: Install dependencies
28+
run: |
29+
python -m pip install --upgrade pip
30+
python -m pip install .[test]
31+
python -m pip install flake8 pytest
32+
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
33+
- name: Clone mujoco_menagerie repository into the tests/ folder
34+
run: |
35+
git clone https://github.com/google-deepmind/mujoco_menagerie
36+
working-directory: ${{ runner.workspace }}/pytorch_kinematics/tests
37+
- name: Lint with flake8
38+
run: |
39+
# stop the build if there are Python syntax errors or undefined names
40+
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
41+
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
42+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
43+
- name: Test with pytest
44+
run: |
45+
pytest

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,8 @@
33
__pycache__
44
temp*
55
build
6-
dist
6+
dist
7+
# These are cloned/generated when testing with mujoco
8+
tests/MUJOCO_LOG.TXT
9+
tests/mujoco_menagerie/
10+

CITATION.cff

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ authors:
99
orcid: https://orcid.org/0000-0002-2439-3262
1010
- family-names: Gupta
1111
given-names: Ashwin
12+
- family-names: Mitrano
13+
given-names: Peter
14+
orcid: https://orcid.org/0000-0002-8701-9809
1215
title: PyTorch Kinematics
1316
doi: 10.5281/zenodo.7700588
1417
version: v0.5.4

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ dependencies = [# Optional
6969
# Similar to `dependencies` above, these must be valid existing
7070
# projects.
7171
[project.optional-dependencies] # Optional
72-
test = ["pytest", "dm_control"]
73-
mujoco = ["dm_control"]
72+
test = ["pytest", "mujoco"]
73+
mujoco = ["mujoco"]
7474

7575
# List URLs that are relevant to your project
7676
#

src/pytorch_kinematics/mjcf.py

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,87 @@
1+
from typing import Union
2+
3+
import mujoco
4+
from mujoco._structs import _MjModelBodyViews as MjModelBodyViews
5+
16
import pytorch_kinematics.transforms as tf
27
from . import chain
38
from . import frame
49

5-
JOINT_TYPE_MAP = {'hinge': 'revolute', "slide": "prismatic"}
10+
# Converts from MuJoCo joint types to pytorch_kinematics joint types
11+
JOINT_TYPE_MAP = {
12+
mujoco.mjtJoint.mjJNT_HINGE: 'revolute',
13+
mujoco.mjtJoint.mjJNT_SLIDE: "prismatic"
14+
}
615

716

8-
def geoms_to_visuals(geom):
17+
def body_to_geoms(m: mujoco.MjModel, body: MjModelBodyViews):
18+
# Find all geoms which have body as parent
919
visuals = []
10-
for g in geom:
11-
if g.type == 'capsule':
12-
param = (g.size[0], g.fromto)
13-
elif g.type == 'sphere':
14-
param = g.size[0]
15-
elif g.type == 'mesh':
16-
param = None
17-
else:
18-
raise ValueError('Invalid geometry type %s.' % g.type)
19-
visuals.append(frame.Visual(offset=tf.Transform3d(rot=g.quat, pos=g.pos), geom_type=g.type, geom_param=param))
20+
for geom_id in range(m.ngeom):
21+
geom = m.geom(geom_id)
22+
if geom.bodyid == body.id:
23+
visuals.append(frame.Visual(offset=tf.Transform3d(rot=geom.quat, pos=geom.pos), geom_type=geom.type,
24+
geom_param=geom.size))
2025
return visuals
2126

2227

23-
def _build_chain_recurse(parent_frame, parent_body):
24-
parent_frame.link.visuals = geoms_to_visuals(parent_body.geom)
25-
for b in parent_body.body:
26-
n_joints = len(b.joint)
27-
if n_joints > 1:
28-
raise ValueError("composite joints not supported (could implement this if needed)")
29-
if n_joints == 1:
30-
joint = b.joint[0]
31-
child_joint = frame.Joint(joint.name, tf.Transform3d(pos=joint.pos), axis=joint.axis,
32-
joint_type=JOINT_TYPE_MAP[joint.type])
33-
else:
34-
child_joint = frame.Joint(b.name + "_imaginary_fixed_joint")
35-
child_link = frame.Link(b.name, offset=tf.Transform3d(rot=b.quat, pos=b.pos))
36-
child_frame = frame.Frame(name=b.name, link=child_link, joint=child_joint)
37-
parent_frame.children = parent_frame.children + (child_frame,)
38-
_build_chain_recurse(child_frame, b)
39-
40-
for site in parent_body.site:
41-
site_link = frame.Link(site.name, offset=tf.Transform3d(rot=site.quat, pos=site.pos))
42-
site_frame = frame.Frame(name=site.name, link=site_link)
43-
parent_frame.children = parent_frame.children + (site_frame,)
44-
45-
46-
def build_chain_from_mjcf(data):
28+
def _build_chain_recurse(m, parent_frame, parent_body):
29+
parent_frame.link.visuals = body_to_geoms(m, parent_body)
30+
# iterate through all bodies that are children of parent_body
31+
for body_id in range(m.nbody):
32+
body = m.body(body_id)
33+
if body.parentid == parent_body.id and body_id != parent_body.id:
34+
n_joints = body.jntnum
35+
if n_joints > 1:
36+
raise ValueError("composite joints not supported (could implement this if needed)")
37+
if n_joints == 1:
38+
# Find the joint for this body
39+
for jntid in body.jntadr:
40+
joint = m.joint(jntid)
41+
child_joint = frame.Joint(joint.name, tf.Transform3d(pos=joint.pos), axis=joint.axis,
42+
joint_type=JOINT_TYPE_MAP[joint.type[0]])
43+
else:
44+
child_joint = frame.Joint(body.name + "_fixed_joint")
45+
child_link = frame.Link(body.name, offset=tf.Transform3d(rot=body.quat, pos=body.pos))
46+
child_frame = frame.Frame(name=body.name, link=child_link, joint=child_joint)
47+
parent_frame.children = parent_frame.children + (child_frame,)
48+
_build_chain_recurse(m, child_frame, body)
49+
50+
# iterate through all sites that are children of parent_body
51+
for site_id in range(m.nsite):
52+
site = m.site(site_id)
53+
if site.bodyid == parent_body.id:
54+
site_link = frame.Link(site.name, offset=tf.Transform3d(rot=site.quat, pos=site.pos))
55+
site_frame = frame.Frame(name=site.name, link=site_link)
56+
parent_frame.children = parent_frame.children + (site_frame,)
57+
58+
59+
def build_chain_from_mjcf(data, body: Union[None, str, int] = None):
4760
"""
4861
Build a Chain object from MJCF data.
4962
5063
Parameters
5164
----------
5265
data : str
5366
MJCF string data.
67+
body : str or int, optional
68+
The name or index of the body to use as the root of the chain. If None, body idx=0 is used.
5469
5570
Returns
5671
-------
5772
chain.Chain
5873
Chain object created from MJCF.
5974
"""
60-
from dm_control import mjcf
61-
62-
model = mjcf.from_xml_string(data)
63-
root_body = model.worldbody.body[0]
75+
m = mujoco.MjModel.from_xml_string(data)
76+
if body is None:
77+
root_body = m.body(0)
78+
else:
79+
root_body = m.body(body)
6480
root_frame = frame.Frame(root_body.name + "_frame",
6581
link=frame.Link(root_body.name,
6682
offset=tf.Transform3d(rot=root_body.quat, pos=root_body.pos)),
6783
joint=frame.Joint())
68-
_build_chain_recurse(root_frame, root_body)
84+
_build_chain_recurse(m, root_frame, root_body)
6985
return chain.Chain(root_frame)
7086

7187

tests/ant.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
<mujoco model="ant">
2-
<compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
2+
<compiler angle="degree"/>
33
<option integrator="RK4" timestep="0.01"/>
44
<custom>
55
<numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>

tests/hopper.xml

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,41 @@
11
<mujoco model="hopper">
2-
<compiler angle="degree" coordinate="global" inertiafromgeom="true"/>
3-
<default>
4-
<joint armature="1" damping="1" limited="true"/>
5-
<geom conaffinity="1" condim="1" contype="1" margin="0.001" friction="0.8 .1 .1" material="geom" rgba="0.8 0.6 .4 1" solimp=".8 .8 .01" solref=".02 1"/>
6-
<motor ctrllimited="true" ctrlrange="-.4 .4"/>
7-
</default>
8-
<option integrator="RK4" timestep="0.002"/>
9-
<worldbody>
10-
<!-- CHANGE: body pos="" deleted for all bodies (you can also set pos="0 0 0", it works)
11-
Interpretation of body pos="" depends on coordinate="global" above.
12-
Bullet doesn't support global coordinates in bodies, little motivation to fix this, as long as it works without pos="" as well.
13-
After this change, Hopper still loads and works in MuJoCo simulator.
14-
-->
15-
<body name="torso">
16-
<joint armature="0" axis="1 0 0" damping="0" limited="false" name="ignore1" pos="0 0 0" stiffness="0" type="slide"/>
17-
<joint armature="0" axis="0 0 1" damping="0" limited="false" name="ignore2" pos="0 0 0" ref="1.25" stiffness="0" type="slide"/>
18-
<joint armature="0" axis="0 1 0" damping="0" limited="false" name="ignore3" pos="0 0 0" stiffness="0" type="hinge"/>
19-
<geom fromto="0 0 1.45 0 0 1.05" name="torso_geom" size="0.05" type="capsule"/>
20-
<body name="thigh">
21-
<joint axis="0 -1 0" name="thigh_joint" pos="0 0 1.05" range="-150 0" type="hinge"/>
22-
<geom fromto="0 0 1.05 0 0 0.6" name="thigh_geom" size="0.05" type="capsule"/>
23-
<body name="leg">
24-
<joint axis="0 -1 0" name="leg_joint" pos="0 0 0.6" range="-150 0" type="hinge"/>
25-
<geom fromto="0 0 0.6 0 0 0.1" name="leg_geom" size="0.04" type="capsule"/>
26-
<body name="foot">
27-
<joint axis="0 -1 0" name="foot_joint" pos="0 0 0.1" range="-45 45" type="hinge"/>
28-
<geom fromto="-0.13 0 0.1 0.26 0 0.1" name="foot_geom" size="0.06" type="capsule"/>
29-
</body>
2+
<compiler angle="degree"/>
3+
<default>
4+
<joint armature="1" damping="1" limited="true"/>
5+
<geom conaffinity="1" condim="1" contype="1" margin="0.001" friction="0.8 .1 .1" rgba="0.8 0.6 .4 1"
6+
solimp=".8 .8 .01" solref=".02 1"/>
7+
<motor ctrllimited="true" ctrlrange="-.4 .4"/>
8+
</default>
9+
<option integrator="RK4" timestep="0.002"/>
10+
<worldbody>
11+
<body name="torso">
12+
<inertial mass="10" pos="0 0 0" diaginertia="0.1 0.1 0.1"/>
13+
<joint armature="0" axis="1 0 0" damping="0" limited="false" name="ignore1" stiffness="0" type="slide"/>
14+
<body>
15+
<inertial mass="1" pos="0 0 0" diaginertia="0.01 0.01 0.01"/>
16+
<joint armature="0" axis="0 0 1" damping="0" limited="false" name="ignore2" ref="1.25" stiffness="0" type="slide"/>
17+
<body>
18+
<joint armature="0" axis="0 1 0" damping="0" limited="false" name="ignore3" stiffness="0" type="hinge"/>
19+
<geom fromto="0 0 1.45 0 0 1.05" name="torso_geom" size="0.05" type="capsule"/>
20+
<body name="thigh">
21+
<joint axis="0 -1 0" name="thigh_joint" pos="0 0 1.05" range="-150 0" type="hinge"/>
22+
<geom fromto="0 0 1.05 0 0 0.6" name="thigh_geom" size="0.05" type="capsule"/>
23+
<body name="leg">
24+
<joint axis="0 -1 0" name="leg_joint" pos="0 0 0.6" range="-150 0" type="hinge"/>
25+
<geom fromto="0 0 0.6 0 0 0.1" name="leg_geom" size="0.04" type="capsule"/>
26+
<body name="foot">
27+
<joint axis="0 -1 0" name="foot_joint" pos="0 0 0.1" range="-45 45" type="hinge"/>
28+
<geom fromto="-0.13 0 0.1 0.26 0 0.1" name="foot_geom" size="0.06" type="capsule"/>
29+
</body>
30+
</body>
31+
</body>
32+
</body>
33+
</body>
3034
</body>
31-
</body>
32-
</body>
33-
</worldbody>
34-
<actuator>
35-
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="thigh_joint"/>
36-
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="leg_joint"/>
37-
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="foot_joint"/>
38-
</actuator>
35+
</worldbody>
36+
<actuator>
37+
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="thigh_joint"/>
38+
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="leg_joint"/>
39+
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="foot_joint"/>
40+
</actuator>
3941
</mujoco>

tests/humanoid.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
<mujoco model='humanoid'>
2-
<compiler inertiafromgeom='true' angle='degree'/>
2+
<compiler angle='degree'/>
33

44
<default>
55
<joint limited='true' damping='1' armature='0' />

tests/test_menagerie.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import os
2+
import pathlib
3+
4+
import numpy as np
5+
6+
import pytorch_kinematics as pk
7+
8+
# Find all files named "scene*.xml" in the "mujoco_menagerie" directory
9+
_MENAGERIE_ROOT = pathlib.Path(__file__).parent / 'mujoco_menagerie'
10+
_XMLS_AND_BODIES = {
11+
# 'agility_cassie/scene.xml': 'cassie-pelvis', # not supported because it has a ball joint
12+
'anybotics_anymal_b/scene.xml': 'base',
13+
'anybotics_anymal_c/scene.xml': 'base',
14+
'franka_emika_panda/scene.xml': 'link0',
15+
'google_barkour_v0/scene.xml': 'chassis',
16+
'google_barkour_v0/scene_barkour.xml': 'chassis',
17+
# 'hello_robot_stretch/scene.xml': 'base_link', # not supported because it has composite joints
18+
'kuka_iiwa_14/scene.xml': 'base',
19+
'rethink_robotics_sawyer/scene.xml': 'base',
20+
'robotiq_2f85/scene.xml': 'base_mount',
21+
'robotis_op3/scene.xml': 'body_link',
22+
'shadow_hand/scene_left.xml': 'lh_forearm',
23+
'shadow_hand/scene_right.xml': 'rh_forearm',
24+
'ufactory_xarm7/scene.xml': 'link_base',
25+
'unitree_a1/scene.xml': 'trunk',
26+
'unitree_go1/scene.xml': 'trunk',
27+
'universal_robots_ur5e/scene.xml': 'base',
28+
'wonik_allegro/scene_left.xml': 'palm',
29+
'wonik_allegro/scene_right.xml': 'palm',
30+
}
31+
32+
33+
def test_menagerie():
34+
for xml_filename, body in _XMLS_AND_BODIES.items():
35+
xml_filename = _MENAGERIE_ROOT / xml_filename
36+
xml_dir = xml_filename.parent
37+
# Menagerie files assume the current working directory is the directory of the scene.xml
38+
os.chdir(xml_dir)
39+
with xml_filename.open('r') as f:
40+
xml = f.read()
41+
chain = pk.build_chain_from_mjcf(xml, body)
42+
print(xml_filename)
43+
print("=" * 32)
44+
print(f"\t {chain.get_frame_names()}")
45+
print(f"\t {chain.get_joint_parameter_names()}")
46+
th = np.zeros(len(chain.get_joint_parameter_names()))
47+
fk_dict = chain.forward_kinematics(th, end_only=True)
48+
49+
50+
if __name__ == '__main__':
51+
test_menagerie()

0 commit comments

Comments
 (0)