Skip to content

Commit 9637a69

Browse files
authored
Merge pull request #81 from erikfrey/no_alloc
Remove inline allocs from smooth and xfrc_accumulate.
2 parents dad5ba2 + c5ca4cb commit 9637a69

File tree

6 files changed

+56
-103
lines changed

6 files changed

+56
-103
lines changed

mujoco/mjx/_src/forward.py

Lines changed: 4 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -635,27 +635,18 @@ def qfrc_actuator_kernel(
635635
def fwd_acceleration(m: Model, d: Data):
636636
"""Add up all non-constraint forces, compute qacc_smooth."""
637637

638-
qfrc_applied = d.qfrc_applied
639-
qfrc_accumulated = xfrc_accumulate(m, d)
640-
641638
@kernel
642-
def _qfrc_smooth(
643-
d: Data,
644-
qfrc_applied: wp.array(ndim=2, dtype=wp.float32),
645-
qfrc_accumulated: wp.array(ndim=2, dtype=wp.float32),
646-
):
639+
def _qfrc_smooth(d: Data):
647640
worldid, dofid = wp.tid()
648641
d.qfrc_smooth[worldid, dofid] = (
649642
d.qfrc_passive[worldid, dofid]
650643
- d.qfrc_bias[worldid, dofid]
651644
+ d.qfrc_actuator[worldid, dofid]
652-
+ qfrc_applied[worldid, dofid]
653-
+ qfrc_accumulated[worldid, dofid]
645+
+ d.qfrc_applied[worldid, dofid]
654646
)
655647

656-
wp.launch(
657-
_qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d, qfrc_applied, qfrc_accumulated]
658-
)
648+
wp.launch(_qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d])
649+
xfrc_accumulate(m, d, d.qfrc_smooth)
659650

660651
smooth.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth)
661652

mujoco/mjx/_src/io.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -441,6 +441,9 @@ def make_data(
441441
d.qfrc_constraint = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
442442
d.qacc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
443443

444+
d.rne_cacc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector)
445+
d.rne_cfrc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector)
446+
444447
d.xfrc_applied = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector)
445448

446449
# internal tmp arrays
@@ -661,6 +664,9 @@ def tile(x):
661664
d.contact.efc_address = wp.array(con_efc_address, dtype=wp.int32, ndim=1)
662665
d.contact.worldid = wp.array(con_worldid, dtype=wp.int32, ndim=1)
663666

667+
d.rne_cacc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector)
668+
d.rne_cfrc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector)
669+
664670
d.efc = _constraint(mjm.nv, d.nworld, d.njmax)
665671
d.efc.J = wp.array(efc_J_fill, dtype=wp.float32, ndim=2)
666672
d.efc.D = wp.array(efc_D_fill, dtype=wp.float32, ndim=1)

mujoco/mjx/_src/smooth.py

Lines changed: 18 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -403,74 +403,68 @@ def factor_m(m: Model, d: Data):
403403
def rne(m: Model, d: Data):
404404
"""Computes inverse dynamics using Newton-Euler algorithm."""
405405

406-
cacc = wp.zeros(shape=(d.nworld, m.nbody), dtype=wp.spatial_vector)
407-
cfrc = wp.zeros(shape=(d.nworld, m.nbody), dtype=wp.spatial_vector)
408-
409406
@kernel
410-
def cacc_gravity(m: Model, cacc: wp.array(dtype=wp.spatial_vector, ndim=2)):
407+
def cacc_gravity(m: Model, d: Data):
411408
worldid = wp.tid()
412-
cacc[worldid, 0] = wp.spatial_vector(wp.vec3(0.0), -m.opt.gravity)
409+
d.rne_cacc[worldid, 0] = wp.spatial_vector(wp.vec3(0.0), -m.opt.gravity)
413410

414411
@kernel
415412
def cacc_level(
416413
m: Model,
417414
d: Data,
418-
cacc: wp.array(dtype=wp.spatial_vector, ndim=2),
419415
leveladr: int,
420416
):
421417
worldid, nodeid = wp.tid()
422418
bodyid = m.body_tree[leveladr + nodeid]
423419
dofnum = m.body_dofnum[bodyid]
424420
pid = m.body_parentid[bodyid]
425421
dofadr = m.body_dofadr[bodyid]
426-
local_cacc = cacc[worldid, pid]
422+
local_cacc = d.rne_cacc[worldid, pid]
427423
for i in range(dofnum):
428424
local_cacc += d.cdof_dot[worldid, dofadr + i] * d.qvel[worldid, dofadr + i]
429-
cacc[worldid, bodyid] = local_cacc
425+
d.rne_cacc[worldid, bodyid] = local_cacc
430426

431427
@kernel
432-
def frc_fn(
433-
d: Data,
434-
cfrc: wp.array(dtype=wp.spatial_vector, ndim=2),
435-
cacc: wp.array(dtype=wp.spatial_vector, ndim=2),
436-
):
428+
def frc_fn(d: Data):
437429
worldid, bodyid = wp.tid()
438-
frc = math.inert_vec(d.cinert[worldid, bodyid], cacc[worldid, bodyid])
430+
frc = math.inert_vec(d.cinert[worldid, bodyid], d.rne_cacc[worldid, bodyid])
439431
frc += math.motion_cross_force(
440432
d.cvel[worldid, bodyid],
441433
math.inert_vec(d.cinert[worldid, bodyid], d.cvel[worldid, bodyid]),
442434
)
443-
cfrc[worldid, bodyid] += frc
435+
d.rne_cfrc[worldid, bodyid] = frc
444436

445437
@kernel
446-
def cfrc_fn(m: Model, cfrc: wp.array(dtype=wp.spatial_vector, ndim=2), leveladr: int):
438+
def cfrc_fn(m: Model, d: Data, leveladr: int):
447439
worldid, nodeid = wp.tid()
448440
bodyid = m.body_tree[leveladr + nodeid]
449441
pid = m.body_parentid[bodyid]
450-
wp.atomic_add(cfrc[worldid], pid, cfrc[worldid, bodyid])
442+
wp.atomic_add(d.rne_cfrc[worldid], pid, d.rne_cfrc[worldid, bodyid])
451443

452444
@kernel
453-
def qfrc_bias(m: Model, d: Data, cfrc: wp.array(dtype=wp.spatial_vector, ndim=2)):
445+
def qfrc_bias(m: Model, d: Data):
454446
worldid, dofid = wp.tid()
455447
bodyid = m.dof_bodyid[dofid]
456-
d.qfrc_bias[worldid, dofid] = wp.dot(d.cdof[worldid, dofid], cfrc[worldid, bodyid])
448+
d.qfrc_bias[worldid, dofid] = wp.dot(
449+
d.cdof[worldid, dofid], d.rne_cfrc[worldid, bodyid]
450+
)
457451

458-
wp.launch(cacc_gravity, dim=[d.nworld], inputs=[m, cacc])
452+
wp.launch(cacc_gravity, dim=[d.nworld], inputs=[m, d])
459453

460454
body_treeadr = m.body_treeadr.numpy()
461455
for i in range(len(body_treeadr)):
462456
beg = body_treeadr[i]
463457
end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1]
464-
wp.launch(cacc_level, dim=(d.nworld, end - beg), inputs=[m, d, cacc, beg])
458+
wp.launch(cacc_level, dim=(d.nworld, end - beg), inputs=[m, d, beg])
465459

466-
wp.launch(frc_fn, dim=[d.nworld, m.nbody], inputs=[d, cfrc, cacc])
460+
wp.launch(frc_fn, dim=[d.nworld, m.nbody], inputs=[d])
467461

468462
for i in reversed(range(len(body_treeadr))):
469463
beg = body_treeadr[i]
470464
end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1]
471-
wp.launch(cfrc_fn, dim=[d.nworld, end - beg], inputs=[m, cfrc, beg])
465+
wp.launch(cfrc_fn, dim=[d.nworld, end - beg], inputs=[m, d, beg])
472466

473-
wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d, cfrc])
467+
wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d])
474468

475469

476470
@event_scope

mujoco/mjx/_src/support.py

Lines changed: 21 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -113,40 +113,25 @@ def _mul_m_sparse_ij(
113113
)
114114

115115

116-
@wp.kernel
117-
def process_level(
118-
body_tree: wp.array(ndim=1, dtype=int),
119-
body_parentid: wp.array(ndim=1, dtype=int),
120-
dof_bodyid: wp.array(ndim=1, dtype=int),
121-
mask: wp.array2d(dtype=wp.bool),
122-
beg: int,
123-
):
124-
dofid, tid_y = wp.tid()
125-
j = beg + tid_y
126-
el = body_tree[j]
127-
parent_id = body_parentid[el]
128-
parent_val = mask[dofid, parent_id]
129-
mask[dofid, el] = parent_val or (dof_bodyid[dofid] == el)
130-
131-
132-
@wp.kernel
133-
def compute_qfrc(
134-
d: Data,
135-
m: Model,
136-
mask: wp.array2d(dtype=wp.bool),
137-
qfrc_total: array2df,
138-
):
139-
worldid, dofid = wp.tid()
140-
accumul = float(0.0)
141-
cdof_vec = d.cdof[worldid, dofid]
142-
rotational_cdof = wp.vec3(cdof_vec[0], cdof_vec[1], cdof_vec[2])
143-
144-
jac = wp.spatial_vector(
145-
cdof_vec[3], cdof_vec[4], cdof_vec[5], cdof_vec[0], cdof_vec[1], cdof_vec[2]
146-
)
147-
148-
for bodyid in range(m.nbody):
149-
if mask[dofid, bodyid]:
116+
@event_scope
117+
def xfrc_accumulate(m: Model, d: Data, qfrc: array2df):
118+
@wp.kernel
119+
def _accumulate(m: Model, d: Data, qfrc: array2df):
120+
worldid, dofid = wp.tid()
121+
cdof = d.cdof[worldid, dofid]
122+
rotational_cdof = wp.vec3(cdof[0], cdof[1], cdof[2])
123+
jac = wp.spatial_vector(cdof[3], cdof[4], cdof[5], cdof[0], cdof[1], cdof[2])
124+
125+
dof_bodyid = m.dof_bodyid[dofid]
126+
accumul = float(0.0)
127+
128+
for bodyid in range(dof_bodyid, m.nbody):
129+
# any body that is in the subtree of dof_bodyid is part of the jacobian
130+
parentid = bodyid
131+
while parentid != 0 and parentid != dof_bodyid:
132+
parentid = m.body_parentid[parentid]
133+
if parentid == 0:
134+
continue # body is not part of the subtree
150135
offset = d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]]
151136
cross_term = wp.cross(rotational_cdof, offset)
152137
accumul += wp.dot(jac, d.xfrc_applied[worldid, bodyid]) + wp.dot(
@@ -158,30 +143,9 @@ def compute_qfrc(
158143
),
159144
)
160145

161-
qfrc_total[worldid, dofid] = accumul
162-
163-
164-
@event_scope
165-
def xfrc_accumulate(m: Model, d: Data) -> array2df:
166-
body_treeadr_np = m.body_treeadr.numpy()
167-
mask = wp.zeros((m.nv, m.nbody), dtype=wp.bool)
168-
169-
for i in range(len(body_treeadr_np)):
170-
beg = body_treeadr_np[i]
171-
end = m.nbody if i == len(body_treeadr_np) - 1 else body_treeadr_np[i + 1]
172-
173-
if end > beg:
174-
wp.launch(
175-
kernel=process_level,
176-
dim=[m.nv, (end - beg)],
177-
inputs=[m.body_tree, m.body_parentid, m.dof_bodyid, mask, beg],
178-
)
179-
180-
qfrc_total = wp.zeros((d.nworld, m.nv), dtype=float)
181-
182-
wp.launch(kernel=compute_qfrc, dim=(d.nworld, m.nv), inputs=[d, m, mask, qfrc_total])
146+
qfrc[worldid, dofid] += accumul
183147

184-
return qfrc_total
148+
wp.launch(kernel=_accumulate, dim=(d.nworld, m.nv), inputs=[m, d, qfrc])
185149

186150

187151
@wp.func

mujoco/mjx/_src/support_test.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -60,20 +60,14 @@ def test_xfrc_accumulated(self):
6060
mjm, mjd, m, d = test_util.fixture("pendula.xml")
6161
xfrc = np.random.randn(*d.xfrc_applied.numpy().shape)
6262
d.xfrc_applied = wp.from_numpy(xfrc, dtype=wp.spatial_vector)
63-
qfrc = xfrc_accumulate(m, d)
63+
qfrc = wp.zeros((1, mjm.nv), dtype=wp.float32)
64+
xfrc_accumulate(m, d, qfrc)
6465

6566
qfrc_expected = np.zeros(m.nv)
6667
xfrc = xfrc[0]
67-
mjd.xfrc_applied[:] = xfrc
6868
for i in range(1, m.nbody):
6969
mujoco.mj_applyFT(
70-
mjm,
71-
mjd,
72-
mjd.xfrc_applied[i, :3],
73-
mjd.xfrc_applied[i, 3:],
74-
mjd.xipos[i],
75-
i,
76-
qfrc_expected,
70+
mjm, mjd, xfrc[i, :3], xfrc[i, 3:], mjd.xipos[i], i, qfrc_expected
7771
)
7872
np.testing.assert_almost_equal(qfrc.numpy()[0], qfrc_expected, 6)
7973

mujoco/mjx/_src/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,10 @@ class Data:
464464
contact: Contact
465465
efc: Constraint
466466

467+
# arrays used for smooth.rne
468+
rne_cacc: wp.array(dtype=wp.spatial_vector, ndim=2)
469+
rne_cfrc: wp.array(dtype=wp.spatial_vector, ndim=2)
470+
467471
# temp arrays
468472
qfrc_integration: wp.array(dtype=wp.float32, ndim=2)
469473
qacc_integration: wp.array(dtype=wp.float32, ndim=2)

0 commit comments

Comments
 (0)