Skip to content

Commit 64d0f57

Browse files
thowellcopybara-github
authored andcommitted
Fixes for MJX tendons and muscle actuators. Fixes #2317.
PiperOrigin-RevId: 714951834 Change-Id: I823b3a01b2b76cb4707cb313afd1a187f9bdb333
1 parent 1c32b5d commit 64d0f57

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

mjx/mujoco/mjx/_src/smooth.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,9 @@ def tendon(m: Model, d: Data) -> Data:
734734
for adr, num in zip(m.tendon_adr, m.tendon_num):
735735
for id_pulley in wrap_id_pulley:
736736
if adr <= id_pulley < adr + num:
737-
divisor[id_pulley : adr + num] = m.wrap_prm[id_pulley]
737+
divisor[id_pulley : adr + num] = np.maximum(
738+
mujoco.mjMINVAL, m.wrap_prm[id_pulley]
739+
)
738740

739741
# process spatial tendon sites
740742
(wrap_id_site,) = np.nonzero(m.wrap_type == WrapType.SITE)

mjx/mujoco/mjx/_src/support.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,9 @@ def _length_circle(
529529
p0n = math.normalize(p0).reshape(-1)
530530
p1n = math.normalize(p1).reshape(-1)
531531

532-
angle = jp.arccos(jp.dot(p0n, p1n))
532+
# clip input to closed interval for jp.arccos to prevent potential nan
533+
# TODO(taylorhowell): add test for case where clip is necessary
534+
angle = jp.arccos(jp.clip(jp.dot(p0n, p1n), -1, 1))
533535

534536
# flip if necessary
535537
cross = p0[1] * p1[0] - p0[0] * p1[1]
@@ -554,7 +556,11 @@ def _is_intersect(
554556
(p2[0] - p1[0]) * (p1[1] - p3[1]) - (p2[1] - p1[1]) * (p1[0] - p3[0])
555557
) / det
556558

557-
return (a >= 0) & (a <= 1) & (b >= 0) & (b <= 1)
559+
return jp.where(
560+
jp.abs(det) < mujoco.mjMINVAL,
561+
0,
562+
(a >= 0) & (a <= 1) & (b >= 0) & (b <= 1),
563+
)
558564

559565

560566
def wrap_circle(
@@ -567,7 +573,9 @@ def wrap_circle(
567573
sqrad = rad * rad
568574
dif = jp.array([d[2] - d[0], d[3] - d[1]])
569575
dd = dif[0] ** 2 + dif[1] ** 2
570-
a = jp.clip(-(dif[0] * d[0] + dif[1] * d[1]) / dd, 0, 1)
576+
a = jp.clip(
577+
-(dif[0] * d[0] + dif[1] * d[1]) / jp.maximum(mujoco.mjMINVAL, dd), 0, 1
578+
)
571579
seg = jp.array([a * dif[0] + d[0], a * dif[1] + d[1]])
572580

573581
point_inside0 = sqlen0 < sqrad
@@ -581,13 +589,21 @@ def wrap_circle(
581589

582590
# construct the two solutions, compute goodness
583591
def _sol(sgn):
584-
sqrt0 = jp.sqrt(sqlen0 - sqrad)
585-
sqrt1 = jp.sqrt(sqlen1 - sqrad)
592+
sqrt0 = jp.sqrt(jp.maximum(mujoco.mjMINVAL, sqlen0 - sqrad))
593+
sqrt1 = jp.sqrt(jp.maximum(mujoco.mjMINVAL, sqlen1 - sqrad))
586594

587-
d00 = (d[0] * sqrad + sgn * rad * d[1] * sqrt0) / sqlen0
588-
d01 = (d[1] * sqrad - sgn * rad * d[0] * sqrt0) / sqlen0
589-
d10 = (d[2] * sqrad - sgn * rad * d[3] * sqrt1) / sqlen1
590-
d11 = (d[3] * sqrad + sgn * rad * d[2] * sqrt1) / sqlen1
595+
d00 = (d[0] * sqrad + sgn * rad * d[1] * sqrt0) / jp.maximum(
596+
mujoco.mjMINVAL, sqlen0
597+
)
598+
d01 = (d[1] * sqrad - sgn * rad * d[0] * sqrt0) / jp.maximum(
599+
mujoco.mjMINVAL, sqlen0
600+
)
601+
d10 = (d[2] * sqrad - sgn * rad * d[3] * sqrt1) / jp.maximum(
602+
mujoco.mjMINVAL, sqlen1
603+
)
604+
d11 = (d[3] * sqrad + sgn * rad * d[2] * sqrt1) / jp.maximum(
605+
mujoco.mjMINVAL, sqlen1
606+
)
591607

592608
sol = jp.array([[d00, d01], [d10, d11]])
593609

@@ -785,9 +801,8 @@ def muscle_gain(
785801

786802
# velocity curve
787803
y = fvmax - 1
788-
FV = fvmax # pylint:disable=invalid-name
789804
FV = jp.where( # pylint:disable=invalid-name
790-
V <= y, fvmax - jp.square(y - V) / jp.maximum(mujoco.mjMINVAL, y), FV
805+
V <= y, fvmax - jp.square(y - V) / jp.maximum(mujoco.mjMINVAL, y), fvmax
791806
)
792807
FV = jp.where(V <= 0, jp.square(V + 1), FV) # pylint:disable=invalid-name
793808
FV = jp.where(V <= -1, 0, FV) # pylint:disable=invalid-name
@@ -845,7 +860,10 @@ def _sigmoid(x):
845860
# sigmoid function over 0 <= x <= 1 using quintic polynomial
846861
# sigmoid: f(x) = 6 * x^5 - 15 * x^4 + 10 * x^3
847862
# solution of f(0) = f'(0) = f''(0) = 0, f(1) = 1, f'(1) = f''(1) = 0
848-
return jp.clip(x**3 * (3 * x * (2 * x - 5) + 10), 0, 1)
863+
sol = x * x * x * (3 * x * (2 * x - 5) + 10)
864+
sol = jp.where(x <= 0, 0, sol)
865+
sol = jp.where(x >= 1, 1, sol)
866+
return sol
849867

850868
# smooth switching
851869
# scale by width, center around 0.5 midpoint, rescale to bounds

0 commit comments

Comments
 (0)