Skip to content

Commit 0ba31e2

Browse files
thowellcopybara-github
authored andcommitted
MJX safe division. Fixes #2776, #2568, #2657.
PiperOrigin-RevId: 794997676 Change-Id: I3100af7c1987a8435b29cc68c4e5c624dc573986
1 parent 07e7417 commit 0ba31e2

File tree

5 files changed

+40
-24
lines changed

5 files changed

+40
-24
lines changed

mjx/mujoco/mjx/_src/collision_primitive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def plane_cylinder(plane: GeomInfo, cylinder: GeomInfo) -> Collision:
120120
# disk parallel to plane: pick x-axis of cylinder, scale by radius
121121
cylinder.mat[:, 0] * cylinder.size[0],
122122
# general configuration: normalize vector, scale by radius
123-
vec / len_ * cylinder.size[0],
123+
math.safe_div(vec, len_) * cylinder.size[0],
124124
)
125125

126126
# project vector on normal

mjx/mujoco/mjx/_src/math.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@
1818

1919
import jax
2020
from jax import numpy as jp
21+
import mujoco
22+
23+
24+
def safe_div(
25+
num: Union[float, jax.Array], den: Union[float, jax.Array]
26+
) -> Union[float, jax.Array]:
27+
"""Safe division for case where denominator is zero."""
28+
return num / (den + mujoco.mjMINVAL * (den == 0))
2129

2230

2331
def matmul_unroll(a: jax.Array, b: jax.Array) -> jax.Array:

mjx/mujoco/mjx/_src/ray.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _ray_quad(
3535
det = b * b - a * c
3636
det_2 = jp.sqrt(det)
3737

38-
x0, x1 = (-b - det_2) / a, (-b + det_2) / a
38+
x0, x1 = math.safe_div(-b - det_2, a), math.safe_div(-b + det_2, a)
3939
x0 = jp.where((det < mujoco.mjMINVAL) | (x0 < 0), jp.inf, x0)
4040
x1 = jp.where((det < mujoco.mjMINVAL) | (x1 < 0), jp.inf, x1)
4141

@@ -48,7 +48,7 @@ def _ray_plane(
4848
vec: jax.Array,
4949
) -> jax.Array:
5050
"""Returns the distance at which a ray intersects with a plane."""
51-
x = -pnt[2] / vec[2]
51+
x = -math.safe_div(pnt[2], vec[2])
5252

5353
valid = vec[2] <= -mujoco.mjMINVAL # z-vec pointing towards front face
5454
valid &= x >= 0
@@ -116,7 +116,7 @@ def _ray_ellipsoid(
116116
"""Returns the distance at which a ray intersects with an ellipsoid."""
117117

118118
# invert size^2
119-
s = 1 / jp.square(size)
119+
s = math.safe_div(1, jp.square(size))
120120

121121
# (x*lvec+lpnt)' * diag(1/size^2) * (x*lvec+lpnt) = 1
122122
svec = s * vec
@@ -142,7 +142,7 @@ def _ray_box(
142142

143143
# side +1, -1
144144
# solution of pnt[i] + x * vec[i] = side * size[i]
145-
x = jp.concatenate([(size - pnt) / vec, (-size - pnt) / vec])
145+
x = jp.concatenate([math.safe_div(size - pnt, vec), -math.safe_div(size + pnt, vec)])
146146

147147
# intersection with face
148148
p0 = pnt[iface[:, 0]] + x * vec[iface[:, 0]]
@@ -170,13 +170,13 @@ def _ray_triangle(
170170
b = -planar[2]
171171
det = A[0, 0] * A[1, 1] - A[1, 0] * A[0, 1]
172172

173-
t0 = (A[1, 1] * b[0] - A[1, 0] * b[1]) / det
174-
t1 = (-A[0, 1] * b[0] + A[0, 0] * b[1]) / det
173+
t0 = math.safe_div(A[1, 1] * b[0] - A[1, 0] * b[1], det)
174+
t1 = math.safe_div(-A[0, 1] * b[0] + A[0, 0] * b[1], det)
175175
valid = (t0 >= 0) & (t1 >= 0) & (t0 + t1 <= 1)
176176

177177
# intersect ray with plane of triangle
178178
nrm = jp.cross(vert[0] - vert[2], vert[1] - vert[2])
179-
dist = jp.dot(vert[2] - pnt, nrm) / jp.dot(vec, nrm)
179+
dist = math.safe_div(jp.dot(vert[2] - pnt, nrm), jp.dot(vec, nrm))
180180
valid &= dist >= 0
181181
dist = jp.where(valid, dist, jp.inf)
182182

mjx/mujoco/mjx/_src/smooth.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def kinematics(m: Model, d: Data) -> Data:
4444
from mujoco.mjx.warp import smooth as mjxw_smooth # pylint: disable=g-import-not-at-top # pytype: disable=import-error
4545
return mjxw_smooth.kinematics(m, d)
4646

47+
4748
def fn(carry, jnt_typs, jnt_pos, jnt_axis, qpos, qpos0, pos, quat):
4849
# calculate joint anchors, axes, body pos and quat in global frame
4950
# also normalize qpos while we're at it
@@ -910,7 +911,9 @@ def _length_moment(pnt0, pnt1, body0, body1):
910911
dif = pnt1 - pnt0
911912
length = math.norm(dif)
912913
vec = jp.where(
913-
length < mujoco.mjMINVAL, jp.array([1.0, 0.0, 0.0]), dif / length
914+
length < mujoco.mjMINVAL,
915+
jp.array([1.0, 0.0, 0.0]),
916+
math.safe_div(dif, length),
914917
)
915918

916919
jacp1, _ = support.jac(m, d, pnt0, body0)
@@ -1387,14 +1390,16 @@ def _momentdot(wpnt0, wpnt1, wvel0, wvel1, body0, body1):
13871390
dpnt = wpnt1 - wpnt0
13881391
norm = math.norm(dpnt)
13891392
dpnt = jp.where(
1390-
norm < mujoco.mjMINVAL, jp.array([1.0, 0.0, 0.0]), dpnt / norm
1393+
norm < mujoco.mjMINVAL,
1394+
jp.array([1.0, 0.0, 0.0]),
1395+
math.safe_div(dpnt, norm),
13911396
)
13921397

13931398
# dvel = d / dt(dpnt)
13941399
dvel = wvel1 - wvel0
13951400
dot = jp.dot(dpnt, dvel)
13961401
dvel += dpnt * -dot
1397-
dvel = jp.where(norm > mujoco.mjMINVAL, dvel / norm, 0.0)
1402+
dvel = jp.where(norm > mujoco.mjMINVAL, math.safe_div(dvel, norm), 0.0)
13981403

13991404
# get endpoint JacobianDots, subtract
14001405
jacp1, _ = support.jac_dot(m, d, wpnt0, body0)

mjx/mujoco/mjx/_src/support.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,12 @@ def __getname(self, name: str):
470470
return name
471471
else:
472472
raise AttributeError('ctrl is not available for this type')
473-
if name == 'qpos' or name == 'qvel' or name == 'qacc' or name.startswith('qfrc_'):
473+
if (
474+
name == 'qpos'
475+
or name == 'qvel'
476+
or name == 'qacc'
477+
or name.startswith('qfrc_')
478+
):
474479
if self.prefix == 'jnt_':
475480
return name
476481
else:
@@ -672,12 +677,12 @@ def _is_intersect(
672677
det = (p4[1] - p3[1]) * (p2[0] - p1[0]) - (p4[0] - p3[0]) * (p2[1] - p1[1])
673678

674679
# compute intersection point on each line
675-
a = (
676-
(p4[0] - p3[0]) * (p1[1] - p3[1]) - (p4[1] - p3[1]) * (p1[0] - p3[0])
677-
) / det
678-
b = (
679-
(p2[0] - p1[0]) * (p1[1] - p3[1]) - (p2[1] - p1[1]) * (p1[0] - p3[0])
680-
) / det
680+
a = math.safe_div(
681+
(p4[0] - p3[0]) * (p1[1] - p3[1]) - (p4[1] - p3[1]) * (p1[0] - p3[0]), det
682+
)
683+
b = math.safe_div(
684+
(p2[0] - p1[0]) * (p1[1] - p3[1]) - (p2[1] - p1[1]) * (p1[0] - p3[0]), det
685+
)
681686

682687
return jp.where(
683688
jp.abs(det) < mujoco.mjMINVAL,
@@ -856,9 +861,7 @@ def _newton(carry, _):
856861
status0 = df > -mjMINVAL
857862

858863
# new point
859-
z_next = z - (1 - converged) * f / jp.where(
860-
jp.abs(df) < mjMINVAL, mjMINVAL, df
861-
)
864+
z_next = z - (1 - converged) * math.safe_div(f, df)
862865

863866
# make sure we are moving to the left; SHOULD NOT OCCUR
864867
status1 = z_next > z
@@ -987,8 +990,8 @@ def wrap(
987990
l1 = jp.sqrt(
988991
(p1[0] - res[3]) * (p1[0] - res[3]) + (p1[1] - res[4]) * (p1[1] - res[4])
989992
)
990-
r2 = p0[2] + (p1[2] - p0[2]) * l0 / (l0 + wlen + l1)
991-
r5 = p0[2] + (p1[2] - p0[2]) * (l0 + wlen) / (l0 + wlen + l1)
993+
r2 = p0[2] + (p1[2] - p0[2]) * math.safe_div(l0, l0 + wlen + l1)
994+
r5 = p0[2] + (p1[2] - p0[2]) * math.safe_div(l0 + wlen, l0 + wlen + l1)
992995
height = jp.abs(r5 - r2)
993996

994997
wlen = jp.where(is_sphere, wlen, jp.sqrt(wlen * wlen + height * height))
@@ -1130,7 +1133,7 @@ def _sigmoid(x):
11301133
# smooth switching
11311134
# scale by width, center around 0.5 midpoint, rescale to bounds
11321135
tau_smooth = tau_deact + (tau_act - tau_deact) * _sigmoid(
1133-
dctrl / smoothing_width + 0.5
1136+
math.safe_div(dctrl, smoothing_width) + 0.5
11341137
)
11351138

11361139
return jp.where(smoothing_width < mujoco.mjMINVAL, tau_hard, tau_smooth)

0 commit comments

Comments
 (0)