@@ -35,7 +35,7 @@ def _ray_quad(
3535 det = b * b - a * c
3636 det_2 = jp .sqrt (det )
3737
38- x0 , x1 = (- b - det_2 ) / a , (- b + det_2 ) / a
38+ x0 , x1 = math . safe_div (- b - det_2 , a ), math . safe_div (- b + det_2 , a )
3939 x0 = jp .where ((det < mujoco .mjMINVAL ) | (x0 < 0 ), jp .inf , x0 )
4040 x1 = jp .where ((det < mujoco .mjMINVAL ) | (x1 < 0 ), jp .inf , x1 )
4141
@@ -48,7 +48,7 @@ def _ray_plane(
4848 vec : jax .Array ,
4949) -> jax .Array :
5050 """Returns the distance at which a ray intersects with a plane."""
51- x = - pnt [2 ] / vec [2 ]
51+ x = - math . safe_div ( pnt [2 ], vec [2 ])
5252
5353 valid = vec [2 ] <= - mujoco .mjMINVAL # z-vec pointing towards front face
5454 valid &= x >= 0
@@ -116,7 +116,7 @@ def _ray_ellipsoid(
116116 """Returns the distance at which a ray intersects with an ellipsoid."""
117117
118118 # invert size^2
119- s = 1 / jp .square (size )
119+ s = math . safe_div ( 1 , jp .square (size ) )
120120
121121 # (x*lvec+lpnt)' * diag(1/size^2) * (x*lvec+lpnt) = 1
122122 svec = s * vec
@@ -142,7 +142,7 @@ def _ray_box(
142142
143143 # side +1, -1
144144 # solution of pnt[i] + x * vec[i] = side * size[i]
145- x = jp .concatenate ([(size - pnt ) / vec , ( - size - pnt ) / vec ])
145+ x = jp .concatenate ([math . safe_div (size - pnt , vec ), - math . safe_div ( size + pnt , vec ) ])
146146
147147 # intersection with face
148148 p0 = pnt [iface [:, 0 ]] + x * vec [iface [:, 0 ]]
@@ -170,13 +170,13 @@ def _ray_triangle(
170170 b = - planar [2 ]
171171 det = A [0 , 0 ] * A [1 , 1 ] - A [1 , 0 ] * A [0 , 1 ]
172172
173- t0 = (A [1 , 1 ] * b [0 ] - A [1 , 0 ] * b [1 ]) / det
174- t1 = (- A [0 , 1 ] * b [0 ] + A [0 , 0 ] * b [1 ]) / det
173+ t0 = math . safe_div (A [1 , 1 ] * b [0 ] - A [1 , 0 ] * b [1 ], det )
174+ t1 = math . safe_div (- A [0 , 1 ] * b [0 ] + A [0 , 0 ] * b [1 ], det )
175175 valid = (t0 >= 0 ) & (t1 >= 0 ) & (t0 + t1 <= 1 )
176176
177177 # intersect ray with plane of triangle
178178 nrm = jp .cross (vert [0 ] - vert [2 ], vert [1 ] - vert [2 ])
179- dist = jp .dot (vert [2 ] - pnt , nrm ) / jp .dot (vec , nrm )
179+ dist = math . safe_div ( jp .dot (vert [2 ] - pnt , nrm ), jp .dot (vec , nrm ) )
180180 valid &= dist >= 0
181181 dist = jp .where (valid , dist , jp .inf )
182182
0 commit comments