Skip to content

Commit eab4705

Browse files
authored
Capsule-capsule collision (#69)
capsule-capsule collision.
1 parent a3016f5 commit eab4705

File tree

5 files changed

+302
-32
lines changed

5 files changed

+302
-32
lines changed

mujoco/mjx/_src/collision_driver_test.py

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,76 @@
1414
# ==============================================================================
1515
"""Tests the collision driver."""
1616

17-
from absl.testing import absltest
18-
1917
import mujoco
2018
from mujoco import mjx
19+
import numpy as np
20+
from absl.testing import parameterized
2121

2222

23-
class ConvexTest(absltest.TestCase):
23+
class ConvexTest(parameterized.TestCase):
2424
"""Tests the convex contact functions."""
2525

2626
_BOX_PLANE = """
27-
<mujoco>
28-
<worldbody>
29-
<geom size="40 40 40" type="plane"/>
30-
<body pos="0 0 0.7" euler="45 0 0">
31-
<freejoint/>
32-
<geom size="0.5 0.5 0.5" type="box"/>
33-
</body>
34-
</worldbody>
35-
</mujoco>
36-
"""
27+
<mujoco>
28+
<worldbody>
29+
<geom size="40 40 40" type="plane"/>
30+
<body pos="0 0 0.7" euler="45 0 0">
31+
<freejoint/>
32+
<geom size="0.5 0.5 0.5" type="box"/>
33+
</body>
34+
</worldbody>
35+
</mujoco>
36+
"""
37+
38+
_CAPSULE_CAPSULE = """
39+
<mujoco model="two_capsules">
40+
<worldbody>
41+
<body>
42+
<joint type="free"/>
43+
<geom fromto="0.62235904 0.58846647 0.651046 1.5330081 0.33564585 0.977849"
44+
size="0.05" type="capsule"/>
45+
</body>
46+
<body>
47+
<joint type="free"/>
48+
<geom fromto="0.5505271 0.60345304 0.476661 1.3900293 0.30709633 0.932082"
49+
size="0.05" type="capsule"/>
50+
</body>
51+
</worldbody>
52+
</mujoco>
53+
"""
54+
55+
_SPHERE_SPHERE = """
56+
<mujoco>
57+
<worldbody>
58+
<body>
59+
<joint type="free"/>
60+
<geom pos="0 0 0" size="0.2" type="sphere"/>
61+
</body>
62+
<body >
63+
<joint type="free"/>
64+
<geom pos="0 0.3 0" size="0.11" type="sphere"/>
65+
</body>
66+
</worldbody>
67+
</mujoco>
68+
"""
3769

38-
def test_box_plane(self):
39-
"""Tests box collision with a plane."""
40-
m = mujoco.MjModel.from_xml_string(self._BOX_PLANE)
70+
@parameterized.parameters(
71+
(_BOX_PLANE),
72+
(_SPHERE_SPHERE),
73+
(_CAPSULE_CAPSULE),
74+
)
75+
def test_convex_collision(self, xml_string):
76+
"""Tests convex collision with different geometries."""
77+
m = mujoco.MjModel.from_xml_string(xml_string)
4178
d = mujoco.MjData(m)
4279
mujoco.mj_forward(m, d)
43-
4480
mx = mjx.put_model(m)
4581
dx = mjx.put_data(m, d)
46-
4782
mjx.collision(mx, dx)
83+
mujoco.mj_step(m, d)
84+
actual_dist = dx.contact.dist.numpy()[0]
85+
actual_pos = dx.contact.pos.numpy()[0, :]
86+
actual_frame = dx.contact.frame.numpy()[0].flatten()
87+
np.testing.assert_array_almost_equal(actual_dist, d.contact.dist[0], 4)
88+
np.testing.assert_array_almost_equal(actual_pos, d.contact.pos[0], 4)
89+
np.testing.assert_array_almost_equal(actual_frame, d.contact.frame[0], 4)

mujoco/mjx/_src/collision_functions.py

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .types import Data
2020
from .types import GeomType
2121
from .math import make_frame
22+
from .math import closest_segment_to_segment_points
2223
from .math import normalize_with_norm
2324
from .support import group_key
2425
from .support import mat33_from_cols
@@ -142,8 +143,17 @@ def _get_info(
142143

143144
return _get_info
144145

146+
145147
@wp.func
146-
def write_contact(d: Data, dist: float, pos: wp.vec3, frame: wp.mat33, margin: float, geoms: wp.vec2i, worldid: int):
148+
def write_contact(
149+
d: Data,
150+
dist: float,
151+
pos: wp.vec3,
152+
frame: wp.mat33,
153+
margin: float,
154+
geoms: wp.vec2i,
155+
worldid: int,
156+
):
147157
active = (dist - margin) < 0
148158
if active:
149159
index = wp.atomic_add(d.ncon, 0, 1)
@@ -154,6 +164,7 @@ def write_contact(d: Data, dist: float, pos: wp.vec3, frame: wp.mat33, margin: f
154164
d.contact.geom[index] = geoms
155165
d.contact.worldid[index] = worldid
156166

167+
157168
@wp.func
158169
def _plane_sphere(
159170
plane_normal: wp.vec3, plane_pos: wp.vec3, sphere_pos: wp.vec3, sphere_radius: float
@@ -164,28 +175,98 @@ def _plane_sphere(
164175

165176

166177
@wp.func
167-
def plane_sphere(plane: GeomPlane, sphere: GeomSphere, worldid: int, d: Data, margin: float, geom_indices: wp.vec2i):
178+
def plane_sphere(
179+
plane: GeomPlane,
180+
sphere: GeomSphere,
181+
worldid: int,
182+
d: Data,
183+
margin: float,
184+
geom_indices: wp.vec2i,
185+
):
168186
dist, pos = _plane_sphere(plane.normal, plane.pos, sphere.pos, sphere.radius)
169187

170188
write_contact(d, dist, pos, make_frame(plane.normal), margin, geom_indices, worldid)
171189

172190

173191
@wp.func
174-
def sphere_sphere(sphere1: GeomSphere, sphere2: GeomSphere, worldid: int, d: Data, margin: float, geom_indices: wp.vec2i):
175-
dir = sphere1.pos - sphere2.pos
192+
def _sphere_sphere(
193+
pos1: wp.vec3,
194+
radius1: float,
195+
pos2: wp.vec3,
196+
radius2: float,
197+
worldid: int,
198+
d: Data,
199+
margin: float,
200+
geom_indices: wp.vec2i,
201+
):
202+
dir = pos2 - pos1
176203
dist = wp.length(dir)
177204
if dist == 0.0:
178205
n = wp.vec3(1.0, 0.0, 0.0)
179206
else:
180207
n = dir / dist
181-
dist = dist - (sphere1.radius + sphere2.radius)
182-
pos = sphere1.pos + n * (sphere1.radius + 0.5 * dist)
208+
dist = dist - (radius1 + radius2)
209+
pos = pos1 + n * (radius1 + 0.5 * dist)
183210

184211
write_contact(d, dist, pos, make_frame(n), margin, geom_indices, worldid)
185212

186213

187214
@wp.func
188-
def plane_capsule(plane: GeomPlane, cap: GeomCapsule, worldid: int, d: Data, margin: float, geom_indices: wp.vec2i):
215+
def sphere_sphere(
216+
sphere1: GeomSphere,
217+
sphere2: GeomSphere,
218+
worldid: int,
219+
d: Data,
220+
margin: float,
221+
geom_indices: wp.vec2i,
222+
):
223+
_sphere_sphere(
224+
sphere1.pos,
225+
sphere1.radius,
226+
sphere2.pos,
227+
sphere2.radius,
228+
worldid,
229+
d,
230+
margin,
231+
geom_indices,
232+
)
233+
234+
235+
@wp.func
236+
def capsule_capsule(
237+
cap1: GeomCapsule,
238+
cap2: GeomCapsule,
239+
worldid: int,
240+
d: Data,
241+
margin: float,
242+
geom_indices: wp.vec2i,
243+
):
244+
axis1 = wp.vec3(cap1.rot[0, 2], cap1.rot[1, 2], cap1.rot[2, 2])
245+
axis2 = wp.vec3(cap2.rot[0, 2], cap2.rot[1, 2], cap2.rot[2, 2])
246+
length1 = cap1.halfsize
247+
length2 = cap2.halfsize
248+
seg1 = axis1 * length1
249+
seg2 = axis2 * length2
250+
251+
pt1, pt2 = closest_segment_to_segment_points(
252+
cap1.pos - seg1,
253+
cap1.pos + seg1,
254+
cap2.pos - seg2,
255+
cap2.pos + seg2,
256+
)
257+
258+
_sphere_sphere(pt1, cap1.radius, pt2, cap2.radius, worldid, d, margin, geom_indices)
259+
260+
261+
@wp.func
262+
def plane_capsule(
263+
plane: GeomPlane,
264+
cap: GeomCapsule,
265+
worldid: int,
266+
d: Data,
267+
margin: float,
268+
geom_indices: wp.vec2i,
269+
):
189270
"""Calculates two contacts between a capsule and a plane."""
190271
n = plane.normal
191272
axis = wp.vec3(cap.rot[0, 2], cap.rot[1, 2], cap.rot[2, 2])
@@ -212,6 +293,7 @@ def plane_capsule(plane: GeomPlane, cap: GeomCapsule, worldid: int, d: Data, mar
212293
(GeomType.PLANE.value, GeomType.SPHERE.value): plane_sphere,
213294
(GeomType.SPHERE.value, GeomType.SPHERE.value): sphere_sphere,
214295
(GeomType.PLANE.value, GeomType.CAPSULE.value): plane_capsule,
296+
(GeomType.CAPSULE.value, GeomType.CAPSULE.value): capsule_capsule,
215297
}
216298

217299

mujoco/mjx/_src/forward.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,6 @@ def euler(m: Model, d: Data):
166166
# integrate damping implicitly
167167

168168
def eulerdamp_sparse(m: Model, d: Data):
169-
170169
@wp.kernel
171170
def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data):
172171
worldId, tid = wp.tid()
@@ -191,10 +190,11 @@ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data):
191190
)
192191

193192
def eulerdamp_fused_dense(m: Model, d: Data):
194-
195193
def tile_eulerdamp(adr: int, size: int, tilesize: int):
196194
@wp.kernel
197-
def eulerdamp(m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int):
195+
def eulerdamp(
196+
m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int
197+
):
198198
worldid, nodeid = wp.tid()
199199
dofid = m.qLD_tile[leveladr + nodeid]
200200
M_tile = wp.tile_load(
@@ -204,8 +204,12 @@ def eulerdamp(m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr:
204204
damping_scaled = damping_tile * m.opt.timestep
205205
qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled)
206206

207-
qfrc_smooth_tile = wp.tile_load(d.qfrc_smooth[worldid], shape=(tilesize,), offset=(dofid,))
208-
qfrc_constraint_tile = wp.tile_load(d.qfrc_constraint[worldid], shape=(tilesize,), offset=(dofid,))
207+
qfrc_smooth_tile = wp.tile_load(
208+
d.qfrc_smooth[worldid], shape=(tilesize,), offset=(dofid,)
209+
)
210+
qfrc_constraint_tile = wp.tile_load(
211+
d.qfrc_constraint[worldid], shape=(tilesize,), offset=(dofid,)
212+
)
209213

210214
qfrc_tile = qfrc_smooth_tile + qfrc_constraint_tile
211215

@@ -225,11 +229,11 @@ def eulerdamp(m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr:
225229
tile_eulerdamp(beg, end - beg, int(qLD_tilesize[i]))
226230

227231
if not m.opt.disableflags & DisableBit.EULERDAMP.value:
228-
if (m.opt.is_sparse):
232+
if m.opt.is_sparse:
229233
eulerdamp_sparse(m, d)
230234
else:
231235
eulerdamp_fused_dense(m, d)
232-
236+
233237
_advance(m, d, d.act_dot, d.qacc_integration)
234238
else:
235239
_advance(m, d, d.act_dot, d.qacc)

mujoco/mjx/_src/math.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16+
from typing import Tuple
1617
import warp as wp
1718

1819
from . import types
@@ -185,3 +186,57 @@ def normalize_with_norm(x: wp.vec3):
185186
if norm == 0.0:
186187
return x, 0.0
187188
return x / norm, norm
189+
190+
191+
@wp.func
192+
def closest_segment_point(a: wp.vec3, b: wp.vec3, pt: wp.vec3) -> wp.vec3:
193+
"""Returns the closest point on the a-b line segment to a point pt."""
194+
ab = b - a
195+
t = wp.dot(pt - a, ab) / (wp.dot(ab, ab) + 1e-6)
196+
return a + wp.clamp(t, 0.0, 1.0) * ab
197+
198+
199+
@wp.func
200+
def closest_segment_point_and_dist(
201+
a: wp.vec3, b: wp.vec3, pt: wp.vec3
202+
) -> Tuple[wp.vec3, wp.float32]:
203+
"""Returns closest point on the line segment and the distance squared."""
204+
closest = closest_segment_point(a, b, pt)
205+
dist = wp.dot((pt - closest), (pt - closest))
206+
return closest, dist
207+
208+
209+
@wp.func
210+
def closest_segment_to_segment_points(
211+
a0: wp.vec3, a1: wp.vec3, b0: wp.vec3, b1: wp.vec3
212+
) -> Tuple[wp.vec3, wp.vec3]:
213+
"""Returns closest points between two line segments."""
214+
215+
dir_a, len_a = normalize_with_norm(a1 - a0)
216+
dir_b, len_b = normalize_with_norm(b1 - b0)
217+
218+
half_len_a = len_a * 0.5
219+
half_len_b = len_b * 0.5
220+
a_mid = a0 + dir_a * half_len_a
221+
b_mid = b0 + dir_b * half_len_b
222+
223+
trans = a_mid - b_mid
224+
225+
dira_dot_dirb = wp.dot(dir_a, dir_b)
226+
dira_dot_trans = wp.dot(dir_a, trans)
227+
dirb_dot_trans = wp.dot(dir_b, trans)
228+
denom = 1.0 - dira_dot_dirb * dira_dot_dirb
229+
230+
orig_t_a = (-dira_dot_trans + dira_dot_dirb * dirb_dot_trans) / (denom + 1e-6)
231+
orig_t_b = dirb_dot_trans + orig_t_a * dira_dot_dirb
232+
t_a = wp.clamp(orig_t_a, -half_len_a, half_len_a)
233+
t_b = wp.clamp(orig_t_b, -half_len_b, half_len_b)
234+
235+
best_a = a_mid + dir_a * t_a
236+
best_b = b_mid + dir_b * t_b
237+
238+
new_a, d1 = closest_segment_point_and_dist(a0, a1, best_b)
239+
new_b, d2 = closest_segment_point_and_dist(b0, b1, best_a)
240+
if d1 < d2:
241+
return new_a, best_b
242+
return best_a, new_b

0 commit comments

Comments
 (0)