@@ -529,7 +529,9 @@ def _length_circle(
529529 p0n = math .normalize (p0 ).reshape (- 1 )
530530 p1n = math .normalize (p1 ).reshape (- 1 )
531531
532- angle = jp .arccos (jp .dot (p0n , p1n ))
532+ # clip input to closed interval for jp.arccos to prevent potential nan
533+ # TODO(taylorhowell): add test for case where clip is necessary
534+ angle = jp .arccos (jp .clip (jp .dot (p0n , p1n ), - 1 , 1 ))
533535
534536 # flip if necessary
535537 cross = p0 [1 ] * p1 [0 ] - p0 [0 ] * p1 [1 ]
@@ -554,7 +556,11 @@ def _is_intersect(
554556 (p2 [0 ] - p1 [0 ]) * (p1 [1 ] - p3 [1 ]) - (p2 [1 ] - p1 [1 ]) * (p1 [0 ] - p3 [0 ])
555557 ) / det
556558
557- return (a >= 0 ) & (a <= 1 ) & (b >= 0 ) & (b <= 1 )
559+ return jp .where (
560+ jp .abs (det ) < mujoco .mjMINVAL ,
561+ 0 ,
562+ (a >= 0 ) & (a <= 1 ) & (b >= 0 ) & (b <= 1 ),
563+ )
558564
559565
560566def wrap_circle (
@@ -567,7 +573,9 @@ def wrap_circle(
567573 sqrad = rad * rad
568574 dif = jp .array ([d [2 ] - d [0 ], d [3 ] - d [1 ]])
569575 dd = dif [0 ] ** 2 + dif [1 ] ** 2
570- a = jp .clip (- (dif [0 ] * d [0 ] + dif [1 ] * d [1 ]) / dd , 0 , 1 )
576+ a = jp .clip (
577+ - (dif [0 ] * d [0 ] + dif [1 ] * d [1 ]) / jp .maximum (mujoco .mjMINVAL , dd ), 0 , 1
578+ )
571579 seg = jp .array ([a * dif [0 ] + d [0 ], a * dif [1 ] + d [1 ]])
572580
573581 point_inside0 = sqlen0 < sqrad
@@ -581,13 +589,21 @@ def wrap_circle(
581589
582590 # construct the two solutions, compute goodness
583591 def _sol (sgn ):
584- sqrt0 = jp .sqrt (sqlen0 - sqrad )
585- sqrt1 = jp .sqrt (sqlen1 - sqrad )
592+ sqrt0 = jp .sqrt (jp . maximum ( mujoco . mjMINVAL , sqlen0 - sqrad ) )
593+ sqrt1 = jp .sqrt (jp . maximum ( mujoco . mjMINVAL , sqlen1 - sqrad ) )
586594
587- d00 = (d [0 ] * sqrad + sgn * rad * d [1 ] * sqrt0 ) / sqlen0
588- d01 = (d [1 ] * sqrad - sgn * rad * d [0 ] * sqrt0 ) / sqlen0
589- d10 = (d [2 ] * sqrad - sgn * rad * d [3 ] * sqrt1 ) / sqlen1
590- d11 = (d [3 ] * sqrad + sgn * rad * d [2 ] * sqrt1 ) / sqlen1
595+ d00 = (d [0 ] * sqrad + sgn * rad * d [1 ] * sqrt0 ) / jp .maximum (
596+ mujoco .mjMINVAL , sqlen0
597+ )
598+ d01 = (d [1 ] * sqrad - sgn * rad * d [0 ] * sqrt0 ) / jp .maximum (
599+ mujoco .mjMINVAL , sqlen0
600+ )
601+ d10 = (d [2 ] * sqrad - sgn * rad * d [3 ] * sqrt1 ) / jp .maximum (
602+ mujoco .mjMINVAL , sqlen1
603+ )
604+ d11 = (d [3 ] * sqrad + sgn * rad * d [2 ] * sqrt1 ) / jp .maximum (
605+ mujoco .mjMINVAL , sqlen1
606+ )
591607
592608 sol = jp .array ([[d00 , d01 ], [d10 , d11 ]])
593609
@@ -785,9 +801,8 @@ def muscle_gain(
785801
786802 # velocity curve
787803 y = fvmax - 1
788- FV = fvmax # pylint:disable=invalid-name
789804 FV = jp .where ( # pylint:disable=invalid-name
790- V <= y , fvmax - jp .square (y - V ) / jp .maximum (mujoco .mjMINVAL , y ), FV
805+ V <= y , fvmax - jp .square (y - V ) / jp .maximum (mujoco .mjMINVAL , y ), fvmax
791806 )
792807 FV = jp .where (V <= 0 , jp .square (V + 1 ), FV ) # pylint:disable=invalid-name
793808 FV = jp .where (V <= - 1 , 0 , FV ) # pylint:disable=invalid-name
@@ -845,7 +860,10 @@ def _sigmoid(x):
845860 # sigmoid function over 0 <= x <= 1 using quintic polynomial
846861 # sigmoid: f(x) = 6 * x^5 - 15 * x^4 + 10 * x^3
847862 # solution of f(0) = f'(0) = f''(0) = 0, f(1) = 1, f'(1) = f''(1) = 0
848- return jp .clip (x ** 3 * (3 * x * (2 * x - 5 ) + 10 ), 0 , 1 )
863+ sol = x * x * x * (3 * x * (2 * x - 5 ) + 10 )
864+ sol = jp .where (x <= 0 , 0 , sol )
865+ sol = jp .where (x >= 1 , 1 , sol )
866+ return sol
849867
850868 # smooth switching
851869 # scale by width, center around 0.5 midpoint, rescale to bounds
0 commit comments