Skip to content

Commit bfde8e1

Browse files
authored
Merge pull request #35 from thowell/contact
update contact
2 parents 4f2fb6c + 65628ed commit bfde8e1

File tree

3 files changed

+126
-59
lines changed

3 files changed

+126
-59
lines changed

mujoco/mjx/_src/constraint.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def _update_efc_row(
2626
pos_aref: wp.float32,
2727
pos_imp: wp.float32,
2828
invweight: wp.float32,
29-
solref: wp.array(ndim=1, dtype=wp.float32),
30-
solimp: wp.array(ndim=1, dtype=wp.float32),
29+
solref: wp.vec2f,
30+
solimp: types.vec5,
3131
margin: wp.float32,
3232
refsafe: bool,
3333
Jqvel: float,
@@ -145,22 +145,26 @@ def _efc_contact_pyramidal(
145145
d: types.Data,
146146
refsafe: bool,
147147
):
148-
worldid, conid, dimid = wp.tid()
148+
conid, dimid = wp.tid()
149149

150-
if d.contact.dim[worldid, conid] != 3:
150+
if conid >= d.ncon_total[0]:
151151
return
152152

153-
pos = d.contact.dist[worldid, conid] - d.contact.includemargin[worldid, conid]
153+
if d.contact.dim[conid] != 3:
154+
return
155+
156+
pos = d.contact.dist[conid] - d.contact.includemargin[conid]
154157
active = pos < 0
155158

156159
if active:
157160
efcid = wp.atomic_add(d.nefc_total, 0, 1)
161+
worldid = d.contact.worldid[conid]
158162
d.efc_worldid[efcid] = worldid
159163

160-
body1 = m.geom_bodyid[d.contact.geom[worldid, conid, 0]]
161-
body2 = m.geom_bodyid[d.contact.geom[worldid, conid, 1]]
164+
body1 = m.geom_bodyid[d.contact.geom[conid][0]]
165+
body2 = m.geom_bodyid[d.contact.geom[conid][1]]
162166

163-
fri0 = d.contact.friction[worldid, conid, 0]
167+
fri0 = d.contact.friction[conid][0]
164168

165169
# pyramidal has common invweight across all edges
166170
invweight = m.body_invweight0[body1, 0] + m.body_invweight0[body2, 0]
@@ -174,16 +178,16 @@ def _efc_contact_pyramidal(
174178
diff_0 = float(0.0)
175179
diff_i = float(0.0)
176180
for xyz in range(3):
177-
con_pos = d.contact.pos[worldid, conid]
181+
con_pos = d.contact.pos[conid]
178182
jac1p = _jac(m, d, con_pos, xyz, body1, i, worldid)
179183
jac2p = _jac(m, d, con_pos, xyz, body2, i, worldid)
180184
jac_dif = jac2p - jac1p
181-
diff_0 += d.contact.frame[worldid, conid][0, xyz] * jac_dif
182-
diff_i += d.contact.frame[worldid, conid][dimid2, xyz] * jac_dif
185+
diff_0 += d.contact.frame[conid][0, xyz] * jac_dif
186+
diff_i += d.contact.frame[conid][dimid2, xyz] * jac_dif
183187
if dimid % 2 == 0:
184-
J = diff_0 + diff_i * d.contact.friction[worldid, conid, dimid2 - 1]
188+
J = diff_0 + diff_i * d.contact.friction[conid][dimid2 - 1]
185189
else:
186-
J = diff_0 - diff_i * d.contact.friction[worldid, conid, dimid2 - 1]
190+
J = diff_0 - diff_i * d.contact.friction[conid][dimid2 - 1]
187191

188192
d.efc_J[efcid, i] = J
189193
Jqvel += J * d.qvel[worldid, i]
@@ -196,9 +200,9 @@ def _efc_contact_pyramidal(
196200
pos,
197201
pos,
198202
invweight,
199-
d.contact.solref[worldid, conid],
200-
d.contact.solimp[worldid, conid],
201-
d.contact.includemargin[worldid, conid],
203+
d.contact.solref[conid],
204+
d.contact.solimp[conid],
205+
d.contact.includemargin[conid],
202206
refsafe,
203207
Jqvel,
204208
)
@@ -225,6 +229,6 @@ def make_constraint(m: types.Model, d: types.Data):
225229
if m.opt.cone == types.ConeType.PYRAMIDAL.value:
226230
wp.launch(
227231
_efc_contact_pyramidal,
228-
dim=(d.nworld, d.ncon, 4),
232+
dim=(d.nconmax, 4),
229233
inputs=[m, d, refsafe],
230234
)

mujoco/mjx/_src/io.py

Lines changed: 84 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
155155
jnt_limited_slide_hinge_adr, dtype=wp.int32, ndim=1
156156
)
157157
m.jnt_type = wp.array(mjm.jnt_type, dtype=wp.int32, ndim=1)
158-
m.jnt_solref = wp.array(mjm.jnt_solref, dtype=wp.float32, ndim=2)
159-
m.jnt_solimp = wp.array(mjm.jnt_solimp, dtype=wp.float32, ndim=2)
158+
m.jnt_solref = wp.array(mjm.jnt_solref, dtype=wp.vec2f, ndim=1)
159+
m.jnt_solimp = wp.array(mjm.jnt_solimp, dtype=types.vec5, ndim=1)
160160
m.jnt_qposadr = wp.array(mjm.jnt_qposadr, dtype=wp.int32, ndim=1)
161161
m.jnt_dofadr = wp.array(mjm.jnt_dofadr, dtype=wp.int32, ndim=1)
162162
m.jnt_axis = wp.array(mjm.jnt_axis, dtype=wp.vec3, ndim=1)
@@ -198,12 +198,19 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
198198
return m
199199

200200

201-
def make_data(mjm: mujoco.MjModel, nworld: int = 1, njmax: int = -1) -> types.Data:
201+
def make_data(
202+
mjm: mujoco.MjModel, nworld: int = 1, nconmax: int = -1, njmax: int = -1
203+
) -> types.Data:
202204
d = types.Data()
203205
d.nworld = nworld
206+
d.ncon_total = wp.zeros((1,), dtype=wp.int32, ndim=1)
204207
d.nefc_total = wp.zeros((1,), dtype=wp.int32, ndim=1)
205208

206209
# TODO(team): move to Model?
210+
if nconmax == -1:
211+
# TODO(team): heuristic for nconmax
212+
nconmax = 512
213+
d.nconmax = nconmax
207214
if njmax == -1:
208215
# TODO(team): heuristic for njmax
209216
njmax = 512
@@ -255,17 +262,18 @@ def make_data(mjm: mujoco.MjModel, nworld: int = 1, njmax: int = -1) -> types.Da
255262
d.cdof_dot = wp.zeros((nworld, mjm.nv), dtype=wp.spatial_vector)
256263
d.qfrc_bias = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
257264
d.contact = types.Contact()
258-
d.contact.dist = wp.zeros((nworld, d.ncon), dtype=wp.float32)
259-
d.contact.pos = wp.zeros((nworld, d.ncon), dtype=wp.vec3f)
260-
d.contact.frame = wp.zeros((nworld, d.ncon), dtype=wp.mat33f)
261-
d.contact.includemargin = wp.zeros((nworld, d.ncon), dtype=wp.float32)
262-
d.contact.friction = wp.zeros((nworld, d.ncon, 5), dtype=wp.float32)
263-
d.contact.solref = wp.zeros((nworld, d.ncon, types.MJ_NREF), dtype=wp.float32)
264-
d.contact.solreffriction = wp.zeros((nworld, d.ncon, types.MJ_NREF), dtype=wp.float32)
265-
d.contact.solimp = wp.zeros((nworld, d.ncon, types.MJ_NIMP), dtype=wp.float32)
266-
d.contact.dim = wp.zeros((nworld, d.ncon), dtype=wp.int32)
267-
d.contact.geom = wp.zeros((nworld, d.ncon, 2), dtype=wp.int32)
268-
d.contact.efc_address = wp.zeros((nworld, d.ncon), dtype=wp.int32)
265+
d.contact.dist = wp.zeros((nconmax,), dtype=wp.float32)
266+
d.contact.pos = wp.zeros((nconmax,), dtype=wp.vec3f)
267+
d.contact.frame = wp.zeros((nconmax,), dtype=wp.mat33f)
268+
d.contact.includemargin = wp.zeros((nconmax,), dtype=wp.float32)
269+
d.contact.friction = wp.zeros((nconmax,), dtype=types.vec5)
270+
d.contact.solref = wp.zeros((nconmax,), dtype=wp.vec2f)
271+
d.contact.solreffriction = wp.zeros((nconmax,), dtype=wp.vec2f)
272+
d.contact.solimp = wp.zeros((nconmax,), dtype=types.vec5)
273+
d.contact.dim = wp.zeros((nconmax,), dtype=wp.int32)
274+
d.contact.geom = wp.zeros((nconmax,), dtype=wp.vec2i)
275+
d.contact.efc_address = wp.zeros((nconmax,), dtype=wp.int32)
276+
d.contact.worldid = wp.zeros((nconmax,), dtype=wp.int32)
269277
d.qfrc_passive = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
270278
d.qfrc_spring = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
271279
d.qfrc_damper = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
@@ -316,13 +324,22 @@ def make_data(mjm: mujoco.MjModel, nworld: int = 1, njmax: int = -1) -> types.Da
316324

317325

318326
def put_data(
319-
mjm: mujoco.MjModel, mjd: mujoco.MjData, nworld: int = 1, njmax: int = -1
327+
mjm: mujoco.MjModel,
328+
mjd: mujoco.MjData,
329+
nworld: int = 1,
330+
nconmax: int = -1,
331+
njmax: int = -1,
320332
) -> types.Data:
321333
d = types.Data()
322334
d.nworld = nworld
335+
d.ncon_total = wp.array([mjd.ncon * nworld], dtype=wp.int32, ndim=1)
323336
d.nefc_total = wp.array([mjd.nefc * nworld], dtype=wp.int32, ndim=1)
324337

325338
# TODO(team): move to Model?
339+
if nconmax == -1:
340+
# TODO(team): heuristic for nconmax
341+
nconmax = 512
342+
d.nconmax = nconmax
326343
if njmax == -1:
327344
# TODO(team): heuristic for njmax
328345
njmax = 512
@@ -437,25 +454,63 @@ def tile(x):
437454
d.efc_force = wp.array(efc_force_fill, dtype=wp.float32, ndim=1)
438455
d.efc_margin = wp.array(efc_margin_fill, dtype=wp.float32, ndim=1)
439456
d.efc_worldid = wp.from_numpy(efc_worldid, dtype=wp.int32)
457+
440458
d.act = wp.array(tile(mjd.act), dtype=wp.float32, ndim=2)
441459
d.act_dot = wp.array(tile(mjd.act_dot), dtype=wp.float32, ndim=2)
442-
d.contact.dist = wp.array(tile(mjd.contact.dist), dtype=wp.float32, ndim=2)
443-
d.contact.pos = wp.array(tile(mjd.contact.pos), dtype=wp.vec3f, ndim=2)
444-
d.contact.frame = wp.array(tile(mjd.contact.frame), dtype=wp.mat33f, ndim=2)
445-
d.contact.includemargin = wp.array(
446-
tile(mjd.contact.includemargin), dtype=wp.float32, ndim=2
460+
461+
ncon = mjd.ncon
462+
con_efc_address = np.zeros(nconmax, dtype=int)
463+
con_worldid = np.zeros(nconmax, dtype=int)
464+
465+
for i in range(nworld):
466+
con_efc_address[i * ncon : (i + 1) * ncon] = mjd.contact.efc_address + i * ncon
467+
con_worldid[i * ncon : (i + 1) * ncon] = i
468+
469+
ncon_fill = nconmax - nworld * ncon
470+
471+
con_dist_fill = np.concatenate(
472+
[np.repeat(mjd.contact.dist, nworld, axis=0), np.zeros(ncon_fill)]
473+
)
474+
con_pos_fill = np.vstack(
475+
[np.repeat(mjd.contact.pos, nworld, axis=0), np.zeros((ncon_fill, 3))]
447476
)
448-
d.contact.friction = wp.array(tile(mjd.contact.friction), dtype=wp.float32, ndim=3)
449-
d.contact.solref = wp.array(tile(mjd.contact.solref), dtype=wp.float32, ndim=3)
450-
d.contact.solreffriction = wp.array(
451-
tile(mjd.contact.solreffriction), dtype=wp.float32, ndim=3
477+
con_frame_fill = np.vstack(
478+
[np.repeat(mjd.contact.frame, nworld, axis=0), np.zeros((ncon_fill, 9))]
452479
)
453-
d.contact.solimp = wp.array(tile(mjd.contact.solimp), dtype=wp.float32, ndim=3)
454-
d.contact.dim = wp.array(tile(mjd.contact.dim), dtype=wp.int32, ndim=2)
455-
d.contact.geom = wp.array(tile(mjd.contact.geom), dtype=wp.int32, ndim=3)
456-
d.contact.efc_address = wp.array(
457-
tile(mjd.contact.efc_address), dtype=wp.int32, ndim=2
480+
con_includemargin_fill = np.concatenate(
481+
[np.repeat(mjd.contact.includemargin, nworld, axis=0), np.zeros(ncon_fill)]
458482
)
483+
con_friction_fill = np.vstack(
484+
[np.repeat(mjd.contact.friction, nworld, axis=0), np.zeros((ncon_fill, 5))]
485+
)
486+
con_solref_fill = np.vstack(
487+
[np.repeat(mjd.contact.solref, nworld, axis=0), np.zeros((ncon_fill, 2))]
488+
)
489+
con_solreffriction_fill = np.vstack(
490+
[np.repeat(mjd.contact.solreffriction, nworld, axis=0), np.zeros((ncon_fill, 2))]
491+
)
492+
con_solimp_fill = np.vstack(
493+
[np.repeat(mjd.contact.solimp, nworld, axis=0), np.zeros((ncon_fill, 5))]
494+
)
495+
con_dim_fill = np.concatenate(
496+
[np.repeat(mjd.contact.dim, nworld, axis=0), np.zeros(ncon_fill)]
497+
)
498+
con_geom_fill = np.vstack(
499+
[np.repeat(mjd.contact.geom, nworld, axis=0), np.zeros((ncon_fill, 2))]
500+
)
501+
502+
d.contact.dist = wp.array(con_dist_fill, dtype=wp.float32, ndim=1)
503+
d.contact.pos = wp.array(con_pos_fill, dtype=wp.vec3f, ndim=1)
504+
d.contact.frame = wp.array(con_frame_fill, dtype=wp.mat33f, ndim=1)
505+
d.contact.includemargin = wp.array(con_includemargin_fill, dtype=wp.float32, ndim=1)
506+
d.contact.friction = wp.array(con_friction_fill, dtype=types.vec5, ndim=1)
507+
d.contact.solref = wp.array(con_solref_fill, dtype=wp.vec2f, ndim=1)
508+
d.contact.solreffriction = wp.array(con_solreffriction_fill, dtype=wp.vec2f, ndim=1)
509+
d.contact.solimp = wp.array(con_solimp_fill, dtype=types.vec5, ndim=1)
510+
d.contact.dim = wp.array(con_dim_fill, dtype=wp.int32, ndim=1)
511+
d.contact.geom = wp.array(con_geom_fill, dtype=wp.vec2i, ndim=1)
512+
d.contact.efc_address = wp.array(con_efc_address, dtype=wp.int32, ndim=1)
513+
d.contact.worldid = wp.array(con_worldid, dtype=wp.int32, ndim=1)
459514

460515
d.xfrc_applied = wp.array(tile(mjd.xfrc_applied), dtype=wp.spatial_vector, ndim=2)
461516
# internal tmp arrays

mujoco/mjx/_src/types.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,15 @@ class ConeType(enum.IntEnum):
126126
ELLIPTIC = mujoco.mjtCone.mjCONE_ELLIPTIC
127127

128128

129+
class vec5f(wp.types.vector(length=5, dtype=wp.float32)):
130+
pass
131+
132+
129133
class vec10f(wp.types.vector(length=10, dtype=wp.float32)):
130134
pass
131135

132136

137+
vec5 = vec5f
133138
vec10 = vec10f
134139
array2df = wp.array2d(dtype=wp.float32)
135140
array3df = wp.array3d(dtype=wp.float32)
@@ -200,8 +205,8 @@ class Model:
200205
jnt_bodyid: wp.array(dtype=wp.int32, ndim=1)
201206
jnt_limited: wp.array(dtype=wp.int32, ndim=1)
202207
jnt_limited_slide_hinge_adr: wp.array(dtype=wp.int32, ndim=1) # warp only
203-
jnt_solref: wp.array(dtype=wp.float32, ndim=2)
204-
jnt_solimp: wp.array(dtype=wp.float32, ndim=2)
208+
jnt_solref: wp.array(dtype=wp.vec2f, ndim=1)
209+
jnt_solimp: wp.array(dtype=vec5, ndim=1)
205210
jnt_type: wp.array(dtype=wp.int32, ndim=1)
206211
jnt_qposadr: wp.array(dtype=wp.int32, ndim=1)
207212
jnt_dofadr: wp.array(dtype=wp.int32, ndim=1)
@@ -246,23 +251,26 @@ class Model:
246251

247252
@wp.struct
248253
class Contact:
249-
dist: wp.array(dtype=wp.float32, ndim=2)
250-
pos: wp.array(dtype=wp.vec3f, ndim=2)
251-
frame: wp.array(dtype=wp.mat33f, ndim=2)
252-
includemargin: wp.array(dtype=wp.float32, ndim=2)
253-
friction: wp.array(dtype=wp.float32, ndim=3)
254-
solref: wp.array(dtype=wp.float32, ndim=3)
255-
solreffriction: wp.array(dtype=wp.float32, ndim=3)
256-
solimp: wp.array(dtype=wp.float32, ndim=3)
257-
dim: wp.array(dtype=wp.int32, ndim=2)
258-
geom: wp.array(dtype=wp.int32, ndim=3)
259-
efc_address: wp.array(dtype=wp.int32, ndim=2)
254+
dist: wp.array(dtype=wp.float32, ndim=1)
255+
pos: wp.array(dtype=wp.vec3f, ndim=1)
256+
frame: wp.array(dtype=wp.mat33f, ndim=1)
257+
includemargin: wp.array(dtype=wp.float32, ndim=1)
258+
friction: wp.array(dtype=vec5, ndim=1)
259+
solref: wp.array(dtype=wp.vec2f, ndim=1)
260+
solreffriction: wp.array(dtype=wp.vec2f, ndim=1)
261+
solimp: wp.array(dtype=vec5, ndim=1)
262+
dim: wp.array(dtype=wp.int32, ndim=1)
263+
geom: wp.array(dtype=wp.vec2i, ndim=1)
264+
efc_address: wp.array(dtype=wp.int32, ndim=1)
265+
worldid: wp.array(dtype=wp.int32, ndim=1)
260266

261267

262268
@wp.struct
263269
class Data:
264270
nworld: int
271+
ncon_total: wp.array(dtype=wp.int32, ndim=1) # warp only
265272
nefc_total: wp.array(dtype=wp.int32, ndim=1) # warp only
273+
nconmax: int
266274
njmax: int
267275
time: float
268276
qpos: wp.array(dtype=wp.float32, ndim=2)

0 commit comments

Comments
 (0)