Skip to content

Commit 53567d7

Browse files
authored
Merge pull request #50 from thowell/actuator_velocity_tile
compute actuator_velocity with tile operations
2 parents 0342f53 + a91d341 commit 53567d7

File tree

4 files changed

+99
-41
lines changed

4 files changed

+99
-41
lines changed

mujoco/mjx/_src/forward.py

Lines changed: 72 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -286,11 +286,11 @@ def qderiv_actuator_fused_kernel(
286286
m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int
287287
):
288288
worldid, nodeid = wp.tid()
289-
offset_nv = m.qderiv_implicit_offset_nv[leveladr + nodeid]
289+
offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid]
290290

291291
# skip tree with no actuators.
292292
if wp.static(actuation_enabled and tilesize_nu != 0):
293-
offset_nu = m.qderiv_implicit_offset_nu[leveladr + nodeid]
293+
offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid]
294294
actuator_moment_tile = wp.tile_load(
295295
d.actuator_moment[worldid],
296296
shape=(tilesize_nu, tilesize_nv),
@@ -340,9 +340,9 @@ def qderiv_actuator_fused_kernel(
340340
block_dim=block_dim,
341341
)
342342

343-
qderiv_tilesize_nv = m.qderiv_implicit_tilesize_nv.numpy()
344-
qderiv_tilesize_nu = m.qderiv_implicit_tilesize_nu.numpy()
345-
qderiv_tileadr = m.qderiv_implicit_tileadr.numpy()
343+
qderiv_tilesize_nv = m.actuator_moment_tilesize_nv.numpy()
344+
qderiv_tilesize_nu = m.actuator_moment_tilesize_nu.numpy()
345+
qderiv_tileadr = m.actuator_moment_tileadr.numpy()
346346

347347
for i in range(len(qderiv_tileadr)):
348348
beg = qderiv_tileadr[i]
@@ -389,17 +389,75 @@ def fwd_position(m: Model, d: Data):
389389
def fwd_velocity(m: Model, d: Data):
390390
"""Velocity-dependent computations."""
391391

392-
# TODO(team): tile operations?
393-
d.actuator_velocity.zero_()
392+
if m.opt.is_sparse:
393+
# TODO(team): sparse version
394+
d.actuator_velocity.zero_()
394395

395-
@wp.kernel
396-
def _actuator_velocity(d: Data):
397-
worldid, actid, dofid = wp.tid()
398-
moment = d.actuator_moment[worldid, actid]
399-
qvel = d.qvel[worldid]
400-
wp.atomic_add(d.actuator_velocity[worldid], actid, moment[dofid] * qvel[dofid])
396+
@wp.kernel
397+
def _actuator_velocity(d: Data):
398+
worldid, actid, dofid = wp.tid()
399+
moment = d.actuator_moment[worldid, actid]
400+
qvel = d.qvel[worldid]
401+
wp.atomic_add(d.actuator_velocity[worldid], actid, moment[dofid] * qvel[dofid])
402+
403+
wp.launch(_actuator_velocity, dim=(d.nworld, m.nu, m.nv), inputs=[d])
404+
else:
405+
406+
def actuator_velocity(
407+
adr: int,
408+
size: int,
409+
tilesize_nu: int,
410+
tilesize_nv: int,
411+
):
412+
@wp.kernel
413+
def _actuator_velocity(
414+
m: Model, d: Data, leveladr: int, velocity: array3df, qvel: array3df
415+
):
416+
worldid, nodeid = wp.tid()
417+
offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid]
418+
offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid]
419+
actuator_moment_tile = wp.tile_load(
420+
d.actuator_moment[worldid],
421+
shape=(tilesize_nu, tilesize_nv),
422+
offset=(offset_nu, offset_nv),
423+
)
424+
qvel_tile = wp.tile_load(
425+
qvel[worldid], shape=(tilesize_nv, 1), offset=(offset_nv, 0)
426+
)
427+
velocity_tile = wp.tile_matmul(actuator_moment_tile, qvel_tile)
428+
429+
wp.tile_store(velocity[worldid], velocity_tile, offset=(offset_nu, 0))
430+
431+
wp.launch_tiled(
432+
_actuator_velocity,
433+
dim=(d.nworld, size),
434+
inputs=[
435+
m,
436+
d,
437+
adr,
438+
d.actuator_velocity.reshape(d.actuator_velocity.shape + (1,)),
439+
d.qvel.reshape(d.qvel.shape + (1,)),
440+
],
441+
block_dim=32,
442+
)
443+
444+
actuator_moment_tilesize_nu = m.actuator_moment_tilesize_nu.numpy()
445+
actuator_moment_tilesize_nv = m.actuator_moment_tilesize_nv.numpy()
446+
actuator_moment_tileadr = m.actuator_moment_tileadr.numpy()
401447

402-
wp.launch(_actuator_velocity, dim=(d.nworld, m.nu, m.nv), inputs=[d])
448+
for i in range(len(actuator_moment_tileadr)):
449+
beg = actuator_moment_tileadr[i]
450+
end = (
451+
m.actuator_moment_tileadr.shape[0]
452+
if i == len(actuator_moment_tileadr) - 1
453+
else actuator_moment_tileadr[i + 1]
454+
)
455+
actuator_velocity(
456+
beg,
457+
end - beg,
458+
int(actuator_moment_tilesize_nu[i]),
459+
int(actuator_moment_tilesize_nv[i]),
460+
)
403461

404462
smooth.com_vel(m, d)
405463
passive.passive(m, d)

mujoco/mjx/_src/forward_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ def _load(self, fname: str, is_sparse: bool = True):
5757

5858
def test_fwd_velocity(self):
5959
"""Tests MJX fwd_velocity."""
60-
_, mjd, m, d = self._load("humanoid/humanoid.xml")
60+
_, mjd, m, d = self._load("humanoid/humanoid.xml", is_sparse=False)
6161

6262
d.actuator_velocity.zero_()
6363
mjx.fwd_velocity(m, d)

mujoco/mjx/_src/io.py

Lines changed: 21 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -139,12 +139,12 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
139139
qLD_tileadr = np.cumsum(tile_off)[:-1]
140140
qLD_tilesize = np.array(sorted(tiles.keys()))
141141

142-
# tiles for implicit integration - needs nu + nv tile size and offset
143-
qderiv_implicit_offset_nv = np.empty(shape=(0,), dtype=int)
144-
qderiv_implicit_offset_nu = np.empty(shape=(0,), dtype=int)
145-
qderiv_implicit_tileadr = np.empty(shape=(0,), dtype=int)
146-
qderiv_implicit_tilesize_nv = np.empty(shape=(0,), dtype=int)
147-
qderiv_implicit_tilesize_nu = np.empty(shape=(0,), dtype=int)
142+
# tiles for actuator_moment - needs nu + nv tile size and offset
143+
actuator_moment_offset_nv = np.empty(shape=(0,), dtype=int)
144+
actuator_moment_offset_nu = np.empty(shape=(0,), dtype=int)
145+
actuator_moment_tileadr = np.empty(shape=(0,), dtype=int)
146+
actuator_moment_tilesize_nv = np.empty(shape=(0,), dtype=int)
147+
actuator_moment_tilesize_nu = np.empty(shape=(0,), dtype=int)
148148

149149
if not support.is_sparse(mjm):
150150
# how many actuators for each tree
@@ -166,18 +166,18 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
166166
act_beg += act_num
167167

168168
sorted_keys = sorted(tiles.keys())
169-
qderiv_implicit_offset_nv = [
169+
actuator_moment_offset_nv = [
170170
t[0] for key in sorted_keys for t in tiles.get(key, [])
171171
]
172-
qderiv_implicit_offset_nu = [
172+
actuator_moment_offset_nu = [
173173
t[1] for key in sorted_keys for t in tiles.get(key, [])
174174
]
175175
tile_off = [0] + [len(tiles[sz]) for sz in sorted(tiles.keys())]
176-
qderiv_implicit_tileadr = np.cumsum(tile_off)[:-1] # offset
177-
qderiv_implicit_tilesize_nv = np.array(
176+
actuator_moment_tileadr = np.cumsum(tile_off)[:-1] # offset
177+
actuator_moment_tilesize_nv = np.array(
178178
[a[0] for a in sorted_keys]
179179
) # for this level
180-
qderiv_implicit_tilesize_nu = np.array(
180+
actuator_moment_tilesize_nu = np.array(
181181
[int(a[1]) for a in sorted_keys]
182182
) # for this level
183183

@@ -193,20 +193,20 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
193193
m.qLD_tile = wp.array(qLD_tile, dtype=wp.int32, ndim=1)
194194
m.qLD_tileadr = wp.array(qLD_tileadr, dtype=wp.int32, ndim=1, device="cpu")
195195
m.qLD_tilesize = wp.array(qLD_tilesize, dtype=wp.int32, ndim=1, device="cpu")
196-
m.qderiv_implicit_offset_nv = wp.array(
197-
qderiv_implicit_offset_nv, dtype=wp.int32, ndim=1
196+
m.actuator_moment_offset_nv = wp.array(
197+
actuator_moment_offset_nv, dtype=wp.int32, ndim=1
198198
)
199-
m.qderiv_implicit_offset_nu = wp.array(
200-
qderiv_implicit_offset_nu, dtype=wp.int32, ndim=1
199+
m.actuator_moment_offset_nu = wp.array(
200+
actuator_moment_offset_nu, dtype=wp.int32, ndim=1
201201
)
202-
m.qderiv_implicit_tileadr = wp.array(
203-
qderiv_implicit_tileadr, dtype=wp.int32, ndim=1, device="cpu"
202+
m.actuator_moment_tileadr = wp.array(
203+
actuator_moment_tileadr, dtype=wp.int32, ndim=1, device="cpu"
204204
)
205-
m.qderiv_implicit_tilesize_nv = wp.array(
206-
qderiv_implicit_tilesize_nv, dtype=wp.int32, ndim=1, device="cpu"
205+
m.actuator_moment_tilesize_nv = wp.array(
206+
actuator_moment_tilesize_nv, dtype=wp.int32, ndim=1, device="cpu"
207207
)
208-
m.qderiv_implicit_tilesize_nu = wp.array(
209-
qderiv_implicit_tilesize_nu, dtype=wp.int32, ndim=1, device="cpu"
208+
m.actuator_moment_tilesize_nu = wp.array(
209+
actuator_moment_tilesize_nu, dtype=wp.int32, ndim=1, device="cpu"
210210
)
211211
m.body_dofadr = wp.array(mjm.body_dofadr, dtype=wp.int32, ndim=1)
212212
m.body_dofnum = wp.array(mjm.body_dofnum, dtype=wp.int32, ndim=1)

mujoco/mjx/_src/types.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -239,11 +239,11 @@ class Model:
239239
qpos_spring: wp.array(dtype=wp.float32, ndim=1)
240240
body_tree: wp.array(dtype=wp.int32, ndim=1) # warp only
241241
body_treeadr: wp.array(dtype=wp.int32, ndim=1) # warp only
242-
qderiv_implicit_offset_nv: wp.array(dtype=wp.int32, ndim=1) # warp only
243-
qderiv_implicit_offset_nu: wp.array(dtype=wp.int32, ndim=1) # warp only
244-
qderiv_implicit_tileadr: wp.array(dtype=wp.int32, ndim=1) # warp only
245-
qderiv_implicit_tilesize_nv: wp.array(dtype=wp.int32, ndim=1) # warp only
246-
qderiv_implicit_tilesize_nu: wp.array(dtype=wp.int32, ndim=1) # warp only
242+
actuator_moment_offset_nv: wp.array(dtype=wp.int32, ndim=1) # warp only
243+
actuator_moment_offset_nu: wp.array(dtype=wp.int32, ndim=1) # warp only
244+
actuator_moment_tileadr: wp.array(dtype=wp.int32, ndim=1) # warp only
245+
actuator_moment_tilesize_nv: wp.array(dtype=wp.int32, ndim=1) # warp only
246+
actuator_moment_tilesize_nu: wp.array(dtype=wp.int32, ndim=1) # warp only
247247
qM_fullm_i: wp.array(dtype=wp.int32, ndim=1) # warp only
248248
qM_fullm_j: wp.array(dtype=wp.int32, ndim=1) # warp only
249249
qM_mulm_i: wp.array(dtype=wp.int32, ndim=1) # warp only

0 commit comments

Comments
 (0)