Skip to content

Commit 1a7ec97

Browse files
btabacopybara-github
authored andcommitted
Import google-deepmind/mujoco_warp from GitHub.
PiperOrigin-RevId: 794774483 Change-Id: I2d8d623bbdfe9f28d4ab96bfc3c9128caa9fe86c
1 parent eff4dda commit 1a7ec97

24 files changed

+1579
-1364
lines changed

mjx/mujoco/mjx/third_party/mujoco_warp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from mujoco.mjx.third_party.mujoco_warp._src.forward import fwd_velocity as fwd_velocity
3737
from mujoco.mjx.third_party.mujoco_warp._src.forward import implicit as implicit
3838
from mujoco.mjx.third_party.mujoco_warp._src.forward import rungekutta4 as rungekutta4
39+
from mujoco.mjx.third_party.mujoco_warp._src.forward import step1 as step1
40+
from mujoco.mjx.third_party.mujoco_warp._src.forward import step2 as step2
3941
from mujoco.mjx.third_party.mujoco_warp._src.inverse import inverse as inverse
4042
from mujoco.mjx.third_party.mujoco_warp._src.io import get_data_into as get_data_into
4143
from mujoco.mjx.third_party.mujoco_warp._src.io import make_data as make_data

mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_convex.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
# TODO(team): improve compile time to enable backward pass
3939
wp.config.enable_backward = False
4040

41-
MULTI_CONTACT_COUNT = 4
41+
MULTI_CONTACT_COUNT = 8
4242
mat3c = wp.types.matrix(shape=(MULTI_CONTACT_COUNT, 3), dtype=float)
4343

4444
_CONVEX_COLLISION_PAIRS = [
@@ -288,6 +288,7 @@ def ccd_kernel(
288288

289289
points = mat3c()
290290

291+
# TODO(kbayes): remove legacy GJK once multicontact can be enabled
291292
if default_gjk:
292293
simplex, normal = gjk_legacy(
293294
gjk_iterations,
@@ -349,10 +350,11 @@ def ccd_kernel(
349350
count = 0
350351
return
351352

352-
for i in range(count):
353-
points[i] = 0.5 * (witness1[i] + witness2[i])
354-
normal = witness1[0] - witness2[0]
355-
frame = make_frame(normal)
353+
for i in range(count):
354+
points[i] = 0.5 * (witness1[i] + witness2[i])
355+
normal = witness1[0] - witness2[0]
356+
frame = make_frame(normal)
357+
356358
for i in range(count):
357359
# limit maximum number of contacts with height field
358360
if _max_contacts_height_field(ngeom, geom_type, geompair2hfgeompair, g1, g2, worldid, ncon_hfield_out):

mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_driver_test.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@
2626
from mujoco.mjx.third_party.mujoco_warp._src import test_util
2727
from mujoco.mjx.third_party.mujoco_warp._src import types
2828

29+
_TOLERANCE = 5e-5
30+
31+
32+
def _assert_eq(a, b, name):
33+
tol = _TOLERANCE * 10
34+
err_msg = f"mismatch: {name}"
35+
np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
36+
2937

3038
class CollisionTest(parameterized.TestCase):
3139
"""Tests the collision contact functions."""
@@ -863,7 +871,35 @@ def test_min_friction(self):
863871
self.assertEqual(d.ncon.numpy()[0], 1)
864872
np.testing.assert_allclose(d.contact.friction.numpy()[0], types.MJ_MINMU)
865873

866-
# TODO(team): test contact parameter mixing
874+
@parameterized.parameters(("1", "1"), ("1", "2"), ("2", "1"))
875+
def test_contact_parameter_mixing(self, priority1, priority2):
876+
_, mjd, m, d = test_util.fixture(
877+
xml=f"""
878+
<mujoco>
879+
<worldbody>
880+
<geom type="plane" size="10 10 .001" friction=".01 .02 .03" priority="{priority1}" condim="1" margin=".002"/>
881+
<body>
882+
<geom type="sphere" size=".1" friction=".123 .456 .789" priority="{priority2}" condim="3" margin=".004"/>
883+
<freejoint/>
884+
</body>
885+
</worldbody>
886+
<keyframe>
887+
<key qpos="0 0 .075 1 0 0 0"/>
888+
</keyframe>
889+
</mujoco>
890+
""",
891+
keyframe=0,
892+
)
893+
894+
mjwarp.collision(m, d)
895+
896+
ncon = d.ncon.numpy()[0]
897+
_assert_eq(ncon, 1, "ncon")
898+
_assert_eq(d.contact.friction.numpy()[0], mjd.contact.friction[0], "friction")
899+
_assert_eq(d.contact.solref.numpy()[0], mjd.contact.solref[0], "solref")
900+
_assert_eq(d.contact.solimp.numpy()[0], mjd.contact.solimp[0], "solimp")
901+
_assert_eq(d.contact.includemargin.numpy()[0], mjd.contact.includemargin[0], "includemargin")
902+
_assert_eq(d.contact.dim.numpy()[0], mjd.contact.dim[0], "dim")
867903

868904

869905
if __name__ == "__main__":

mjx/mujoco/mjx/third_party/mujoco_warp/_src/collision_gjk.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
mat43 = wp.types.matrix(shape=(4, 3), dtype=float)
4242
mat63 = wp.types.matrix(shape=(6, 3), dtype=float)
4343

44-
MULTI_CONTACT_COUNT = 4
44+
MULTI_CONTACT_COUNT = 8
4545
mat3c = wp.types.matrix(shape=(MULTI_CONTACT_COUNT, 3), dtype=float)
4646

4747

@@ -93,6 +93,13 @@ class SupportPoint:
9393
vertex_index: int
9494

9595

96+
@wp.func
97+
def discrete_geoms(g1: int, g2: int):
98+
return (g1 == int(GeomType.MESH.value) or g1 == int(GeomType.BOX.value) or g1 == int(GeomType.HFIELD.value)) and (
99+
g2 == int(GeomType.MESH.value) or g2 == int(GeomType.BOX.value) or g2 == int(GeomType.HFIELD.value)
100+
)
101+
102+
96103
@wp.func
97104
def _support_margin(geom: Geom, geomtype: int, dir: wp.vec3):
98105
sp = SupportPoint()
@@ -590,6 +597,7 @@ def _gjk(
590597
use_margin: bool,
591598
):
592599
"""Find distance within a tolerance between two geoms."""
600+
is_discrete = discrete_geoms(geomtype1, geomtype2)
593601
cutoff2 = cutoff * cutoff
594602
simplex = mat43()
595603
simplex1 = mat43()
@@ -598,7 +606,7 @@ def _gjk(
598606
simplex_index2 = wp.vec4i()
599607
n = int(0)
600608
coordinates = wp.vec4() # barycentric coordinates
601-
epsilon = 0.5 * tolerance * tolerance
609+
epsilon = wp.where(is_discrete, 0.0, 0.5 * tolerance * tolerance)
602610

603611
# set initial guess
604612
x_k = x1_0 - x2_0
@@ -1193,12 +1201,14 @@ def _polytope4(
11931201

11941202

11951203
@wp.func
1196-
def _epa(tolerance2: float, epa_iterations: int, pt: Polytope, geom1: Geom, geom2: Geom, geomtype1: int, geomtype2: int):
1204+
def _epa(tolerance: float, epa_iterations: int, pt: Polytope, geom1: Geom, geom2: Geom, geomtype1: int, geomtype2: int):
11971205
"""Recover penetration data from two geoms in contact given an initial polytope."""
1206+
is_discrete = discrete_geoms(geomtype1, geomtype2)
11981207
upper = FLOAT_MAX
11991208
upper2 = FLOAT_MAX
12001209
idx = int(-1)
12011210
pidx = int(-1)
1211+
epsilon = wp.where(is_discrete, 1e-15, tolerance * tolerance)
12021212

12031213
for k in range(epa_iterations):
12041214
pidx = int(idx)
@@ -1235,9 +1245,19 @@ def _epa(tolerance2: float, epa_iterations: int, pt: Polytope, geom1: Geom, geom
12351245
upper = upper_k
12361246
upper2 = upper * upper
12371247

1238-
if upper - lower < tolerance2:
1248+
if upper - lower < epsilon:
12391249
break
12401250

1251+
# check if vertex wi is a repeated support point
1252+
if is_discrete:
1253+
found_repeated = bool(False)
1254+
for i in range(pt.nvert - 1):
1255+
if pt.vert_index1[i] == pt.vert_index1[wi] and pt.vert_index2[i] == pt.vert_index2[wi]:
1256+
found_repeated = True
1257+
break
1258+
if found_repeated:
1259+
break
1260+
12411261
pt.nmap = _delete_face(pt, idx)
12421262
pt.nhorizon = _add_edge(pt, pt.face[idx][0], pt.face[idx][1])
12431263
pt.nhorizon = _add_edge(pt, pt.face[idx][1], pt.face[idx][2])
@@ -1251,7 +1271,7 @@ def _epa(tolerance2: float, epa_iterations: int, pt: Polytope, geom1: Geom, geom
12511271
if pt.face_index[i] == -2:
12521272
continue
12531273

1254-
if wp.dot(pt.face_pr[i], pt.vert[wi]) - pt.face_norm2[i] > MJ_MINVAL:
1274+
if wp.dot(pt.face_pr[i], pt.vert[wi]) - pt.face_norm2[i] > 1e-10:
12551275
pt.nmap = _delete_face(pt, i)
12561276
pt.nhorizon = _add_edge(pt, pt.face[i][0], pt.face[i][1])
12571277
pt.nhorizon = _add_edge(pt, pt.face[i][1], pt.face[i][2])
@@ -1694,7 +1714,7 @@ def plane_normal(v1: wp.vec3, v2: wp.vec3, n: wp.vec3):
16941714

16951715
@wp.func
16961716
def halfspace(a: wp.vec3, n: wp.vec3, p: wp.vec3):
1697-
return wp.dot(p - a, n) > -MJ_MINVAL
1717+
return wp.dot(p - a, n) > -1e-10
16981718

16991719

17001720
@wp.func
@@ -1725,7 +1745,7 @@ def polygon_clip(face1: polyverts, nface1: int, face2: polyverts, nface2: int, n
17251745
# compute plane normal and distance to plane for each vertex
17261746
pn = polyverts()
17271747
pd = polyvec()
1728-
for i in range(nface1):
1748+
for i in range(nface1 - 1):
17291749
pdi, pni = plane_normal(face1[i], face1[i + 1], n)
17301750
pd[i] = pdi
17311751
pn[i] = pni
@@ -1734,14 +1754,11 @@ def polygon_clip(face1: polyverts, nface1: int, face2: polyverts, nface2: int, n
17341754
pn[nface1 - 1] = pni
17351755

17361756
# reserve 2 * max_sides as max sides for a clipped polygon
1737-
polygon1 = polyclip()
1738-
polygon2 = polyclip()
1757+
polygon = polyclip()
1758+
clipped = polyclip()
17391759
npolygon = nface2
17401760
nclipped = int(0)
17411761

1742-
polygon = polygon1
1743-
clipped = polygon2
1744-
17451762
for i in range(nface2):
17461763
polygon[i] = face2[i]
17471764

@@ -1768,7 +1785,7 @@ def polygon_clip(face1: polyverts, nface1: int, face2: polyverts, nface2: int, n
17681785

17691786
# add new vertex to clipped polygon where PQ intersects the clipping edge
17701787
t, res = plane_intersect(pn[e], pd[e], P, Q)
1771-
if t < 0.0 or t > 1.0:
1788+
if t >= 0.0 and t <= 1.0:
17721789
clipped[nclipped] = res
17731790
nclipped += 1
17741791

@@ -1790,8 +1807,8 @@ def polygon_clip(face1: polyverts, nface1: int, face2: polyverts, nface2: int, n
17901807
# no pruning needed
17911808
for i in range(npolygon):
17921809
witness2[i] = polygon[i]
1793-
witness1[i] = witness2[i] + dir
1794-
return npolygon, witness2, witness1
1810+
witness1[i] = witness2[i] - dir
1811+
return npolygon, witness1, witness2
17951812

17961813

17971814
# recover multiple contacts from EPA polytope
@@ -2125,11 +2142,14 @@ def ccd(
21252142
witness2[0] = result.x2
21262143
return result.dist, 1, witness1, witness2
21272144

2128-
dist, x1, x2, idx = _epa(tolerance * tolerance, epa_iterations, pt, geom1, geom2, geomtype1, geomtype2)
2145+
dist, x1, x2, idx = _epa(tolerance, epa_iterations, pt, geom1, geom2, geomtype1, geomtype2)
2146+
if idx == -1:
2147+
return FLOAT_MAX, 0, witness1, witness2
2148+
21292149
if (
21302150
multiccd
2131-
and (geomtype1 == int(GeomType.BOX.value) or geomtype1 == int(GeomType.MESH.value))
2132-
and (geomtype2 == int(GeomType.BOX.value) or geomtype2 == int(GeomType.MESH.value))
2151+
and (geomtype1 == int(GeomType.BOX.value) or (geomtype1 == int(GeomType.MESH.value) and geom1.mesh_polyadr > -1))
2152+
and (geomtype2 == int(GeomType.BOX.value) or (geomtype2 == int(GeomType.MESH.value) and geom2.mesh_polyadr > -1))
21332153
):
21342154
num, w1, w2 = multicontact(pt, pt.face[idx], x1, x2, geom1, geom2, geomtype1, geomtype2)
21352155
if num > 0:

0 commit comments

Comments
 (0)