Skip to content

Commit a179363

Browse files
authored
Merge pull request #74 from erikfrey/collision_driver_fixes
Fixes to collision driver.
2 parents ebf1409 + afcca0d commit a179363

File tree

5 files changed

+111
-63
lines changed

5 files changed

+111
-63
lines changed

mujoco/mjx/_src/collision_driver.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,11 @@ def get_contact_solver_params_kernel(
327327
p2 = m.geom_priority[g2]
328328
mix = where(p1 == p2, mix, where(p1 > p2, 1.0, 0.0))
329329

330+
condim1 = m.geom_condim[g1]
331+
condim2 = m.geom_condim[g2]
332+
condim = where(p1 == p2, wp.max(condim1, condim2), where(p1 > p2, condim1, condim2))
333+
d.contact.dim[tid] = condim
334+
330335
if m.geom_solref[g1].x > 0.0 and m.geom_solref[g2].x > 0.0:
331336
d.contact.solref[tid] = mix * m.geom_solref[g1] + (1.0 - mix) * m.geom_solref[g2]
332337
else:
@@ -344,6 +349,7 @@ def group_contacts_by_type_kernel(
344349
d: Data,
345350
):
346351
worldid, tid = wp.tid()
352+
347353
if tid >= d.broadphase_result_count[worldid]:
348354
return
349355

@@ -490,16 +496,30 @@ def _nxn_broadphase(m: Model, d: Data):
490496
margin2 = m.geom_margin[geom2]
491497
pos1 = d.geom_xpos[worldid, geom1]
492498
pos2 = d.geom_xpos[worldid, geom2]
499+
xmat1 = d.geom_xmat[worldid, geom1]
500+
xmat2 = d.geom_xmat[worldid, geom2]
493501
size1 = m.geom_rbound[geom1]
494502
size2 = m.geom_rbound[geom2]
495503

496504
bound = size1 + size2 + wp.max(margin1, margin2)
497505
dif = pos2 - pos1
498-
sphere_filter = wp.dot(dif, dif) <= bound * bound
506+
507+
if size1 != 0.0 and size2 != 0.0:
508+
# neither geom is a plane
509+
dist_sq = wp.dot(dif, dif)
510+
bounds_filter = dist_sq <= bound * bound
511+
elif size1 == 0.0:
512+
# geom1 is a plane
513+
dist = wp.dot(dif, wp.vec3(xmat1[0, 2], xmat1[1, 2], xmat1[2, 2]))
514+
bounds_filter = dist <= bound
515+
else:
516+
# geom2 is a plane
517+
dist = wp.dot(-dif, wp.vec3(xmat2[0, 2], xmat2[1, 2], xmat2[2, 2]))
518+
bounds_filter = dist <= bound
499519

500520
geom_filter = _geom_filter(m, geom1, geom2, filterparent)
501521

502-
if sphere_filter and geom_filter:
522+
if bounds_filter and geom_filter:
503523
pairid = wp.atomic_add(d.broadphase_result_count, 0, 1)
504524
d.broadphase_pairs[worldid, pairid] = _geom_pair(m, geom1, geom2)
505525

mujoco/mjx/_src/collision_driver_test.py

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,51 @@
1717
import mujoco
1818
from mujoco import mjx
1919
import numpy as np
20+
from absl.testing import absltest
2021
from absl.testing import parameterized
2122

2223

23-
class ConvexTest(parameterized.TestCase):
24-
"""Tests the convex contact functions."""
24+
class PrimitiveTest(parameterized.TestCase):
25+
"""Tests the primitive contact functions."""
2526

26-
_BOX_PLANE = """
27+
_MJCFS = {
28+
"box_plane": """
2729
<mujoco>
2830
<worldbody>
2931
<geom size="40 40 40" type="plane"/>
30-
<body pos="0 0 0.7" euler="45 0 0">
32+
<body pos="0 0 0.3" euler="45 0 0">
3133
<freejoint/>
3234
<geom size="0.5 0.5 0.5" type="box"/>
3335
</body>
3436
</worldbody>
3537
</mujoco>
36-
"""
37-
38-
_CAPSULE_CAPSULE = """
38+
""",
39+
"plane_sphere": """
40+
<mujoco>
41+
<worldbody>
42+
<geom size="40 40 40" type="plane"/>
43+
<body pos="0 0 0.2" euler="45 0 0">
44+
<freejoint/>
45+
<geom size="0.5" type="sphere"/>
46+
</body>
47+
</worldbody>
48+
</mujoco>
49+
""",
50+
"sphere_sphere": """
51+
<mujoco>
52+
<worldbody>
53+
<body>
54+
<joint type="free"/>
55+
<geom pos="0 0 0" size="0.2" type="sphere"/>
56+
</body>
57+
<body >
58+
<joint type="free"/>
59+
<geom pos="0 0.3 0" size="0.11" type="sphere"/>
60+
</body>
61+
</worldbody>
62+
</mujoco>
63+
""",
64+
"capsule_capsule": """
3965
<mujoco model="two_capsules">
4066
<worldbody>
4167
<body>
@@ -50,40 +76,45 @@ class ConvexTest(parameterized.TestCase):
5076
</body>
5177
</worldbody>
5278
</mujoco>
53-
"""
54-
55-
_SPHERE_SPHERE = """
79+
""",
80+
"plane_capsule": """
5681
<mujoco>
5782
<worldbody>
58-
<body>
59-
<joint type="free"/>
60-
<geom pos="0 0 0" size="0.2" type="sphere"/>
61-
</body>
62-
<body >
63-
<joint type="free"/>
64-
<geom pos="0 0.3 0" size="0.11" type="sphere"/>
83+
<geom size="40 40 40" type="plane"/>
84+
<body pos="0 0 0.0" euler="30 30 0">
85+
<freejoint/>
86+
<geom size="0.05 0.05" type="capsule"/>
6587
</body>
6688
</worldbody>
6789
</mujoco>
68-
"""
90+
""",
91+
}
6992

7093
@parameterized.parameters(
71-
(_BOX_PLANE),
72-
(_SPHERE_SPHERE),
73-
(_CAPSULE_CAPSULE),
94+
"box_plane",
95+
"plane_sphere",
96+
"sphere_sphere",
97+
"plane_capsule",
98+
"capsule_capsule",
7499
)
75-
def test_convex_collision(self, xml_string):
76-
"""Tests convex collision with different geometries."""
77-
m = mujoco.MjModel.from_xml_string(xml_string)
100+
def test_contact(self, name):
101+
"""Tests contact calculation with different collision functions."""
102+
m = mujoco.MjModel.from_xml_string(self._MJCFS[name])
78103
d = mujoco.MjData(m)
79104
mujoco.mj_forward(m, d)
80105
mx = mjx.put_model(m)
81106
dx = mjx.put_data(m, d)
82107
mjx.collision(mx, dx)
83-
mujoco.mj_step(m, d)
84-
actual_dist = dx.contact.dist.numpy()[0]
85-
actual_pos = dx.contact.pos.numpy()[0, :]
86-
actual_frame = dx.contact.frame.numpy()[0].flatten()
87-
np.testing.assert_array_almost_equal(actual_dist, d.contact.dist[0], 4)
88-
np.testing.assert_array_almost_equal(actual_pos, d.contact.pos[0], 4)
89-
np.testing.assert_array_almost_equal(actual_frame, d.contact.frame[0], 4)
108+
mujoco.mj_collision(m, d)
109+
self.assertEqual(d.ncon, dx.ncon.numpy()[0])
110+
for i in range(d.ncon):
111+
actual_dist = dx.contact.dist.numpy()[i]
112+
actual_pos = dx.contact.pos.numpy()[i, :]
113+
actual_frame = dx.contact.frame.numpy()[i].flatten()
114+
np.testing.assert_array_almost_equal(actual_dist, d.contact.dist[i], 4)
115+
np.testing.assert_array_almost_equal(actual_pos, d.contact.pos[i], 4)
116+
np.testing.assert_array_almost_equal(actual_frame, d.contact.frame[i], 4)
117+
118+
119+
if __name__ == "__main__":
120+
absltest.main()

mujoco/mjx/_src/collision_functions.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from .math import closest_segment_to_segment_points
2323
from .math import normalize_with_norm
2424
from .support import group_key
25-
from .support import mat33_from_cols
2625

2726

2827
@wp.struct
@@ -279,7 +278,8 @@ def plane_capsule(
279278
else:
280279
b = wp.vec3(0.0, 0.0, 1.0)
281280

282-
frame = mat33_from_cols(n, b, wp.cross(n, b))
281+
c = wp.cross(n, b)
282+
frame = wp.mat33(n[0], n[1], n[2], b[0], b[1], b[2], c[0], c[1], c[2])
283283
segment = axis * cap.halfsize
284284

285285
dist1, pos1 = _plane_sphere(n, plane.pos, cap.pos + segment, cap.radius)
@@ -289,13 +289,6 @@ def plane_capsule(
289289
write_contact(d, dist2, pos2, frame, margin, geom_indices, worldid)
290290

291291

292-
@wp.func
293-
def distance_point_plane(plane_normal: wp.vec3, plane_pos: wp.vec3, point: wp.vec3):
294-
plane_normal = wp.normalize(plane_normal)
295-
dist = wp.dot(point - plane_pos, plane_normal)
296-
return dist, plane_pos - plane_normal * dist
297-
298-
299292
@wp.func
300293
def plane_box(
301294
plane: GeomPlane,
@@ -305,29 +298,31 @@ def plane_box(
305298
margin: float,
306299
geom_indices: wp.vec2i,
307300
):
308-
contact_count = int(0)
301+
count = int(0)
302+
corner = wp.vec3()
303+
dist = wp.dot(box.pos - plane.pos, plane.normal)
309304

310-
# Check all 8 corners of the box
305+
# test all corners, pick bottom 4
311306
for i in range(8):
312-
corner = wp.vec3(box.size.x * 0.5, box.size.y * 0.5, box.size.z * 0.5)
313-
if i % 2 == 0:
314-
corner.x = -corner.x
315-
if (i // 2) % 2 == 0:
316-
corner.y = -corner.y
317-
if i < 4:
318-
corner.z = -corner.z
319-
320-
corner_world = box.rot * (corner) + box.pos
321-
322-
dist, pos = distance_point_plane(plane.normal, plane.pos, corner_world)
323-
324-
if dist < 0.0:
325-
write_contact(
326-
d, dist, pos, make_frame(plane.normal), margin, geom_indices, worldid
327-
)
328-
contact_count += 1
329-
330-
if contact_count >= 4:
307+
# get corner in local coordinates
308+
corner.x = wp.select(i & 1, -box.size.x, box.size.x)
309+
corner.y = wp.select(i & 2, -box.size.y, box.size.y)
310+
corner.z = wp.select(i & 4, -box.size.z, box.size.z)
311+
312+
# get corner in global coordinates relative to box center
313+
corner = box.rot * corner
314+
315+
# compute distance to plane, skip if too far or pointing up
316+
ldist = wp.dot(plane.normal, corner)
317+
if dist + ldist > margin or ldist > 0:
318+
continue
319+
320+
cdist = dist + ldist
321+
frame = make_frame(plane.normal)
322+
pos = corner + box.pos + (plane.normal * cdist / -2.0)
323+
write_contact(d, cdist, pos, frame, margin, geom_indices, worldid)
324+
count += 1
325+
if count >= 4:
331326
break
332327

333328

mujoco/mjx/_src/io.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
248248
m.geom_bodyid = wp.array(mjm.geom_bodyid, dtype=wp.int32, ndim=1)
249249
m.geom_conaffinity = wp.array(mjm.geom_conaffinity, dtype=wp.int32, ndim=1)
250250
m.geom_contype = wp.array(mjm.geom_contype, dtype=wp.int32, ndim=1)
251+
m.geom_condim = wp.array(mjm.geom_condim, dtype=wp.int32, ndim=1)
251252
m.geom_pos = wp.array(mjm.geom_pos, dtype=wp.vec3, ndim=1)
252253
m.geom_quat = wp.array(mjm.geom_quat, dtype=wp.quat, ndim=1)
253254
m.geom_size = wp.array(mjm.geom_size, dtype=wp.vec3, ndim=1)

mujoco/mjx/_src/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ class Model:
292292
geom_bodyid: wp.array(dtype=wp.int32, ndim=1)
293293
geom_conaffinity: wp.array(dtype=wp.int32, ndim=1)
294294
geom_contype: wp.array(dtype=wp.int32, ndim=1)
295+
geom_condim: wp.array(dtype=wp.int32, ndim=1)
295296
geom_pos: wp.array(dtype=wp.vec3, ndim=1)
296297
geom_quat: wp.array(dtype=wp.quat, ndim=1)
297298
geom_size: wp.array(dtype=wp.vec3, ndim=1)

0 commit comments

Comments
 (0)