Skip to content

Commit c81daa3

Browse files
authored
Merge pull request #25 from eric-heiden/fwd-actuation
fwd_actuation, (transmission)
2 parents 547a2f7 + 20aad23 commit c81daa3

File tree

9 files changed

+259
-13
lines changed

9 files changed

+259
-13
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Exclude temporary folders
22
*.egg-info/
33
env/
4+
.history
45

56
# Python byte-compiled / optimized / DLL files
67
__pycache__/

mujoco/mjx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from ._src.forward import euler
1919
from ._src.forward import forward
20+
from ._src.forward import fwd_actuation
2021
from ._src.forward import fwd_acceleration
2122
from ._src.forward import fwd_position
2223
from ._src.forward import fwd_velocity
@@ -31,6 +32,7 @@
3132
from ._src.smooth import kinematics
3233
from ._src.smooth import rne
3334
from ._src.smooth import solve_m
35+
from ._src.smooth import transmission
3436
from ._src.support import is_sparse
3537
from ._src.test_util import benchmark
3638
from ._src.types import *

mujoco/mjx/_src/forward.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from . import passive
2222
from . import smooth
2323

24-
from .types import array2df
24+
from .types import array2df, array3df
2525
from .types import Model
2626
from .types import Data
2727
from .types import MJ_MINVAL
@@ -218,7 +218,7 @@ def fwd_position(m: Model, d: Data):
218218
smooth.factor_m(m, d)
219219
# TODO(team): collision_driver.collision
220220
# TODO(team): constraint.make_constraint
221-
# TODO(team): smooth.transmission
221+
smooth.transmission(m, d)
222222

223223

224224
def fwd_velocity(m: Model, d: Data):
@@ -241,6 +241,64 @@ def _actuator_velocity(d: Data):
241241
smooth.rne(m, d)
242242

243243

244+
def fwd_actuation(m: Model, d: Data):
245+
"""Actuation-dependent computations."""
246+
if not m.nu:
247+
return
248+
249+
# TODO support stateful actuators
250+
251+
@wp.kernel
252+
def _force(
253+
m: Model,
254+
ctrl: array2df,
255+
# outputs
256+
force: array2df,
257+
):
258+
worldid, dofid = wp.tid()
259+
gain = m.actuator_gainprm[dofid, 0]
260+
bias = m.actuator_biasprm[dofid, 0]
261+
# TODO support gain types other than FIXED
262+
c = ctrl[worldid, dofid]
263+
if m.actuator_ctrllimited[dofid]:
264+
r = m.actuator_ctrlrange[dofid]
265+
c = wp.clamp(c, r[0], r[1])
266+
f = gain * c + bias
267+
if m.actuator_forcelimited[dofid]:
268+
r = m.actuator_forcerange[dofid]
269+
f = wp.clamp(f, r[0], r[1])
270+
force[worldid, dofid] = f
271+
272+
wp.launch(
273+
_force, dim=[d.nworld, m.nu], inputs=[m, d.ctrl], outputs=[d.actuator_force]
274+
)
275+
276+
@wp.kernel
277+
def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df):
278+
worldid, vid = wp.tid()
279+
280+
s = float(0.0)
281+
for uid in range(m.nu):
282+
# TODO consider using Tile API or transpose moment for better access pattern
283+
s += moment[worldid, uid, vid] * force[worldid, uid]
284+
jntid = m.dof_jntid[vid]
285+
if m.jnt_actfrclimited[jntid]:
286+
r = m.jnt_actfrcrange[jntid]
287+
s = wp.clamp(s, r[0], r[1])
288+
qfrc[worldid, vid] = s
289+
290+
wp.launch(
291+
_qfrc,
292+
dim=(d.nworld, m.nv),
293+
inputs=[m, d.actuator_moment, d.actuator_force],
294+
outputs=[d.qfrc_actuator],
295+
)
296+
297+
# TODO actuator-level gravity compensation, skip if added as passive force
298+
299+
return d
300+
301+
244302
def fwd_acceleration(m: Model, d: Data):
245303
"""Add up all non-constraint forces, compute qacc_smooth."""
246304

@@ -269,7 +327,7 @@ def forward(m: Model, d: Data):
269327
# TODO(team): sensor.sensor_pos
270328
fwd_velocity(m, d)
271329
# TODO(team): sensor.sensor_vel
272-
# TODO(team): fwd_actuation
330+
fwd_actuation(m, d)
273331
fwd_acceleration(m, d)
274332
# TODO(team): sensor.sensor_acc
275333

mujoco/mjx/_src/forward_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import numpy as np
2121
import warp as wp
2222

23+
wp.config.verify_cuda = True
24+
2325
import mujoco
2426
from mujoco import mjx
2527

@@ -42,6 +44,8 @@ def _load(self, fname: str, is_sparse: bool = True):
4244
mjd = mujoco.MjData(mjm)
4345
mujoco.mj_resetDataKeyframe(mjm, mjd, 1) # reset to stand_on_left_leg
4446
mjd.qvel = np.random.uniform(low=-0.01, high=0.01, size=mjd.qvel.shape)
47+
mjd.ctrl = np.random.normal(scale=10, size=mjd.ctrl.shape)
48+
mjd.act = np.random.normal(scale=10, size=mjd.act.shape)
4549
mujoco.mj_forward(mjm, mjd)
4650
m = mjx.put_model(mjm)
4751
d = mjx.put_data(mjm, mjd)
@@ -59,6 +63,21 @@ def test_fwd_velocity(self):
5963
)
6064
_assert_eq(d.qfrc_bias.numpy()[0], mjd.qfrc_bias, "qfrc_bias")
6165

66+
def test_fwd_actuation(self):
67+
"""Tests MJX fwd_actuation."""
68+
mjm, mjd, m, d = self._load("humanoid/humanoid.xml", is_sparse=False)
69+
70+
mujoco.mj_fwdActuation(mjm, mjd)
71+
72+
for arr in (d.actuator_force, d.qfrc_actuator):
73+
arr.zero_()
74+
75+
mjx.fwd_actuation(m, d)
76+
77+
_assert_eq(d.ctrl.numpy()[0], mjd.ctrl, "ctrl")
78+
_assert_eq(d.actuator_force.numpy()[0], mjd.actuator_force, "actuator_force")
79+
_assert_eq(d.qfrc_actuator.numpy()[0], mjd.qfrc_actuator, "qfrc_actuator")
80+
6281
def test_fwd_acceleration(self):
6382
"""Tests MJX fwd_acceleration."""
6483
_, mjd, m, d = self._load("humanoid/humanoid.xml", is_sparse=False)

mujoco/mjx/_src/io.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
122122
m.jnt_axis = wp.array(mjm.jnt_axis, dtype=wp.vec3, ndim=1)
123123
m.jnt_pos = wp.array(mjm.jnt_pos, dtype=wp.vec3, ndim=1)
124124
m.jnt_stiffness = wp.array(mjm.jnt_stiffness, dtype=wp.float32, ndim=1)
125+
m.jnt_actfrclimited = wp.array(mjm.jnt_actfrclimited, dtype=wp.bool, ndim=1)
126+
m.jnt_actfrcrange = wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2, ndim=1)
125127
m.geom_pos = wp.array(mjm.geom_pos, dtype=wp.vec3, ndim=1)
126128
m.geom_quat = wp.array(mjm.geom_quat, dtype=wp.quat, ndim=1)
127129
m.site_pos = wp.array(mjm.site_pos, dtype=wp.vec3, ndim=1)
@@ -132,8 +134,17 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
132134
m.dof_Madr = wp.array(mjm.dof_Madr, dtype=wp.int32, ndim=1)
133135
m.dof_armature = wp.array(mjm.dof_armature, dtype=wp.float32, ndim=1)
134136
m.dof_damping = wp.array(mjm.dof_damping, dtype=wp.float32, ndim=1)
135-
m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.int32, ndim=1)
136-
m.actuator_actrange = wp.array(mjm.actuator_actrange, dtype=wp.vec2f, ndim=1)
137+
m.actuator_trntype = wp.array(mjm.actuator_trntype, dtype=wp.int32, ndim=1)
138+
m.actuator_trnid = wp.array(mjm.actuator_trnid, dtype=wp.int32, ndim=2)
139+
m.actuator_ctrllimited = wp.array(mjm.actuator_ctrllimited, dtype=wp.bool, ndim=1)
140+
m.actuator_ctrlrange = wp.array(mjm.actuator_ctrlrange, dtype=wp.vec2, ndim=1)
141+
m.actuator_forcelimited = wp.array(mjm.actuator_forcelimited, dtype=wp.bool, ndim=1)
142+
m.actuator_forcerange = wp.array(mjm.actuator_forcerange, dtype=wp.vec2, ndim=1)
143+
m.actuator_gainprm = wp.array(mjm.actuator_gainprm, dtype=wp.float32, ndim=2)
144+
m.actuator_biasprm = wp.array(mjm.actuator_biasprm, dtype=wp.float32, ndim=2)
145+
m.actuator_gear = wp.array(mjm.actuator_gear, dtype=wp.spatial_vector, ndim=1)
146+
m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.bool, ndim=1)
147+
m.actuator_actrange = wp.array(mjm.actuator_actrange, dtype=wp.vec2, ndim=1)
137148
m.actuator_actadr = wp.array(mjm.actuator_actadr, dtype=wp.int32, ndim=1)
138149
m.actuator_dyntype = wp.array(mjm.actuator_dyntype, dtype=wp.int32, ndim=1)
139150
m.actuator_dynprm = wp.array(mjm.actuator_dynprm, dtype=types.vec10f, ndim=1)
@@ -167,6 +178,10 @@ def make_data(mjm: mujoco.MjModel, nworld: int = 1) -> types.Data:
167178
d.site_xmat = wp.zeros((nworld, mjm.nsite), dtype=wp.mat33)
168179
d.cinert = wp.zeros((nworld, mjm.nbody), dtype=types.vec10)
169180
d.cdof = wp.zeros((nworld, mjm.nv), dtype=wp.spatial_vector)
181+
d.ctrl = wp.zeros((nworld, mjm.nu), dtype=wp.float32)
182+
d.actuator_velocity = wp.zeros((nworld, mjm.nu), dtype=wp.float32)
183+
d.actuator_force = wp.zeros((nworld, mjm.nu), dtype=wp.float32)
184+
d.actuator_length = wp.zeros((nworld, mjm.nu), dtype=wp.float32)
170185
d.actuator_moment = wp.zeros((nworld, mjm.nu, mjm.nv), dtype=wp.float32)
171186
d.crb = wp.zeros((nworld, mjm.nbody), dtype=types.vec10)
172187
if support.is_sparse(mjm):
@@ -178,7 +193,6 @@ def make_data(mjm: mujoco.MjModel, nworld: int = 1) -> types.Data:
178193
d.act_dot = wp.zeros((nworld, mjm.na), dtype=wp.float32)
179194
d.act = wp.zeros((nworld, mjm.na), dtype=wp.float32)
180195
d.qLDiagInv = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
181-
d.actuator_velocity = wp.zeros((nworld, mjm.nu), dtype=wp.float32)
182196
d.cvel = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector)
183197
d.cdof_dot = wp.zeros((nworld, mjm.nv), dtype=wp.spatial_vector)
184198
d.qfrc_bias = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
@@ -247,12 +261,15 @@ def tile(x):
247261
d.site_xmat = wp.array(tile(mjd.site_xmat), dtype=wp.mat33, ndim=2)
248262
d.cinert = wp.array(tile(mjd.cinert), dtype=types.vec10, ndim=2)
249263
d.cdof = wp.array(tile(mjd.cdof), dtype=wp.spatial_vector, ndim=2)
250-
d.actuator_moment = wp.array(tile(actuator_moment), dtype=wp.float32, ndim=3)
251264
d.crb = wp.array(tile(mjd.crb), dtype=types.vec10, ndim=2)
252265
d.qM = wp.array(tile(qM), dtype=wp.float32, ndim=3)
253266
d.qLD = wp.array(tile(qLD), dtype=wp.float32, ndim=3)
254267
d.qLDiagInv = wp.array(tile(mjd.qLDiagInv), dtype=wp.float32, ndim=2)
268+
d.ctrl = wp.array(tile(mjd.ctrl), dtype=wp.float32, ndim=2)
255269
d.actuator_velocity = wp.array(tile(mjd.actuator_velocity), dtype=wp.float32, ndim=2)
270+
d.actuator_force = wp.array(tile(mjd.actuator_force), dtype=wp.float32, ndim=2)
271+
d.actuator_length = wp.array(tile(mjd.actuator_length), dtype=wp.float32, ndim=2)
272+
d.actuator_moment = wp.array(tile(actuator_moment), dtype=wp.float32, ndim=3)
256273
d.cvel = wp.array(tile(mjd.cvel), dtype=wp.spatial_vector, ndim=2)
257274
d.cdof_dot = wp.array(tile(mjd.cdof_dot), dtype=wp.spatial_vector, ndim=2)
258275
d.qfrc_bias = wp.array(tile(mjd.qfrc_bias), dtype=wp.float32, ndim=2)

mujoco/mjx/_src/math.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def quat_to_mat(quat: wp.quat) -> wp.mat33:
6262
)
6363

6464

65+
@wp.func
66+
def quat_inv(quat: wp.quat) -> wp.quat:
67+
return wp.quat(quat[0], -quat[1], -quat[2], -quat[3])
68+
69+
6570
@wp.func
6671
def inert_vec(i: types.vec10, v: wp.spatial_vector) -> wp.spatial_vector:
6772
"""mju_mulInertVec: multiply 6D vector (rotation, translation) by 6D inertia matrix."""

mujoco/mjx/_src/smooth.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
# ==============================================================================
1515

1616
import warp as wp
17-
1817
from . import math
1918

2019
from .types import Model
2120
from .types import Data
2221
from .types import array2df
2322
from .types import array3df
2423
from .types import vec10
24+
from .types import JointType, TrnType
2525

2626

2727
def kinematics(m: Model, d: Data):
@@ -433,6 +433,75 @@ def qfrc_bias(m: Model, d: Data, cfrc: wp.array(dtype=wp.spatial_vector, ndim=2)
433433
wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d, cfrc])
434434

435435

436+
def transmission(m: Model, d: Data):
437+
"""Computes actuator/transmission lengths and moments."""
438+
if not m.nu:
439+
return d
440+
441+
@wp.kernel
442+
def _transmission(
443+
m: Model,
444+
d: Data,
445+
# outputs
446+
length: array2df,
447+
moment: array3df,
448+
):
449+
worldid, actid = wp.tid()
450+
qpos = d.qpos[worldid]
451+
jntid = m.actuator_trnid[actid, 0]
452+
jnt_typ = m.jnt_type[jntid]
453+
qadr = m.jnt_qposadr[jntid]
454+
vadr = m.jnt_dofadr[jntid]
455+
trntype = m.actuator_trntype[actid]
456+
gear = m.actuator_gear[actid]
457+
if trntype == wp.static(TrnType.JOINT.value) or trntype == wp.static(
458+
TrnType.JOINTINPARENT.value
459+
):
460+
if jnt_typ == wp.static(JointType.FREE.value):
461+
length[worldid, actid] = 0.0
462+
if trntype == wp.static(TrnType.JOINTINPARENT.value):
463+
quat_neg = math.quat_inv(
464+
wp.quat(qpos[qadr + 3], qpos[qadr + 4], qpos[qadr + 5], qpos[qadr + 6])
465+
)
466+
gearaxis = math.rot_vec_quat(wp.spatial_bottom(gear), quat_neg)
467+
moment[worldid, actid, vadr + 0] = gear[0]
468+
moment[worldid, actid, vadr + 1] = gear[1]
469+
moment[worldid, actid, vadr + 2] = gear[2]
470+
moment[worldid, actid, vadr + 3] = gearaxis[0]
471+
moment[worldid, actid, vadr + 4] = gearaxis[1]
472+
moment[worldid, actid, vadr + 5] = gearaxis[2]
473+
else:
474+
for i in range(6):
475+
moment[worldid, actid, vadr + i] = gear[i]
476+
elif jnt_typ == wp.static(JointType.BALL.value):
477+
q = wp.quat(qpos[qadr + 0], qpos[qadr + 1], qpos[qadr + 2], qpos[qadr + 3])
478+
axis_angle = math.quat_to_vel(q)
479+
gearaxis = wp.spatial_top(gear) # [:3]
480+
if trntype == wp.static(TrnType.JOINTINPARENT.value):
481+
quat_neg = math.quat_inv(q)
482+
gearaxis = math.rot_vec_quat(gearaxis, quat_neg)
483+
length[worldid, actid] = wp.dot(axis_angle, gearaxis)
484+
for i in range(3):
485+
moment[worldid, actid, vadr + i] = gearaxis[i]
486+
elif jnt_typ == wp.static(JointType.SLIDE.value) or jnt_typ == wp.static(
487+
JointType.HINGE.value
488+
):
489+
length[worldid, actid] = qpos[qadr] * gear[0]
490+
moment[worldid, actid, vadr] = gear[0]
491+
else:
492+
wp.printf("unrecognized joint type")
493+
else:
494+
# TODO handle site, tendon transmission types
495+
wp.printf("unhandled transmission type %d\n", trntype)
496+
497+
wp.launch(
498+
_transmission,
499+
dim=[d.nworld, m.nu],
500+
inputs=[m, d],
501+
outputs=[d.actuator_length, d.actuator_moment],
502+
)
503+
504+
436505
def com_vel(m: Model, d: Data):
437506
"""Computes cvel, cdof_dot."""
438507

mujoco/mjx/_src/smooth_test.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
import numpy as np
2323
import warp as wp
2424

25-
import mujoco
26-
from mujoco import mjx
25+
wp.config.verify_cuda = True
2726

2827
from . import test_util
2928

@@ -141,6 +140,26 @@ def test_com_vel(self):
141140
_assert_eq(d.cvel.numpy()[0], mjd.cvel, "cvel")
142141
_assert_eq(d.cdof_dot.numpy()[0], mjd.cdof_dot, "cdof_dot")
143142

143+
def test_transmission(self):
144+
"""Tests transmission."""
145+
mjm, mjd, m, d = test_util.fixture("pendula.xml")
146+
147+
for arr in (d.actuator_length, d.actuator_moment):
148+
arr.zero_()
149+
150+
actuator_moment = np.zeros((mjm.nu, mjm.nv))
151+
mujoco.mju_sparse2dense(
152+
actuator_moment,
153+
mjd.actuator_moment,
154+
mjd.moment_rownnz,
155+
mjd.moment_rowadr,
156+
mjd.moment_colind,
157+
)
158+
159+
mjx._src.smooth.transmission(m, d)
160+
_assert_eq(d.actuator_length.numpy()[0], mjd.actuator_length, "actuator_length")
161+
_assert_eq(d.actuator_moment.numpy()[0], actuator_moment, "actuator_moment")
162+
144163

145164
if __name__ == "__main__":
146165
wp.init()

0 commit comments

Comments
 (0)