Skip to content

Commit 2c64beb

Browse files
authored
Feature: MJX convert functions (#8)
* dump * feat: tuned optax to solve logchol * dump * fix: remove WIP examples * fix: bump version
1 parent d1000da commit 2c64beb

File tree

4 files changed

+141
-2
lines changed

4 files changed

+141
-2
lines changed

mujoco_sysid/mjx/convert.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import jax.numpy as np
2+
3+
4+
def theta2pseudo(theta: np.ndarray) -> np.ndarray:
5+
m = theta[0]
6+
h = theta[1:4]
7+
I_xx, I_xy, I_yy, I_xz, I_yz, I_zz = theta[4:]
8+
9+
I_bar = np.array([[I_xx, I_xy, I_xz], [I_xy, I_yy, I_yz], [I_xz, I_yz, I_zz]])
10+
11+
Sigma = 0.5 * np.trace(I_bar) * np.eye(3) - I_bar
12+
13+
pseudo_inertia = np.zeros((4, 4))
14+
pseudo_inertia = pseudo_inertia.at[:3, :3].set(Sigma)
15+
pseudo_inertia = pseudo_inertia.at[:3, 3].set(h)
16+
pseudo_inertia = pseudo_inertia.at[3, :3].set(h)
17+
pseudo_inertia = pseudo_inertia.at[3, 3].set(m)
18+
19+
return pseudo_inertia
20+
21+
22+
def pseudo2theta(pseudo_inertia: np.ndarray) -> np.ndarray:
23+
m = pseudo_inertia[3, 3]
24+
h = pseudo_inertia[:3, 3]
25+
Sigma = pseudo_inertia[:3, :3]
26+
27+
I_bar = np.trace(Sigma) * np.eye(3) - Sigma
28+
29+
I_xx = I_bar[0, 0]
30+
I_xy = I_bar[0, 1]
31+
I_yy = I_bar[1, 1]
32+
I_xz = I_bar[0, 2]
33+
I_yz = I_bar[1, 2]
34+
I_zz = I_bar[2, 2]
35+
36+
theta = np.array([m, h[0], h[1], h[2], I_xx, I_xy, I_yy, I_xz, I_yz, I_zz])
37+
38+
return theta
39+
40+
41+
def logchol2chol(log_cholesky):
42+
alpha, d1, d2, d3, s12, s23, s13, t1, t2, t3 = log_cholesky
43+
44+
exp_alpha = np.exp(alpha)
45+
exp_d1 = np.exp(d1)
46+
exp_d2 = np.exp(d2)
47+
exp_d3 = np.exp(d3)
48+
49+
U = np.zeros((4, 4))
50+
U = U.at[0, 0].set(exp_d1)
51+
U = U.at[0, 1].set(s12)
52+
U = U.at[0, 2].set(s13)
53+
U = U.at[0, 3].set(t1)
54+
U = U.at[1, 1].set(exp_d2)
55+
U = U.at[1, 2].set(s23)
56+
U = U.at[1, 3].set(t2)
57+
U = U.at[2, 2].set(exp_d3)
58+
U = U.at[2, 3].set(t3)
59+
U = U.at[3, 3].set(1)
60+
61+
U *= exp_alpha
62+
63+
return U
64+
65+
66+
def chol2logchol(U: np.ndarray) -> np.ndarray:
67+
alpha = np.log(U[3, 3])
68+
d1 = np.log(U[0, 0] / U[3, 3])
69+
d2 = np.log(U[1, 1] / U[3, 3])
70+
d3 = np.log(U[2, 2] / U[3, 3])
71+
s12 = U[0, 1] / U[3, 3]
72+
s23 = U[1, 2] / U[3, 3]
73+
s13 = U[0, 2] / U[3, 3]
74+
t1 = U[0, 3] / U[3, 3]
75+
t2 = U[1, 3] / U[3, 3]
76+
t3 = U[2, 3] / U[3, 3]
77+
return np.array([alpha, d1, d2, d3, s12, s23, s13, t1, t2, t3])
78+
79+
80+
def logchol2theta(log_cholesky: np.ndarray) -> np.ndarray:
81+
alpha, d1, d2, d3, s12, s23, s13, t1, t2, t3 = log_cholesky
82+
83+
exp_d1 = np.exp(d1)
84+
exp_d2 = np.exp(d2)
85+
exp_d3 = np.exp(d3)
86+
87+
theta = np.array(
88+
[
89+
1,
90+
t1,
91+
t2,
92+
t3,
93+
s23**2 + t2**2 + t3**2 + exp_d2**2 + exp_d3**2,
94+
-s12 * exp_d2 - s13 * s23 - t1 * t2,
95+
s12**2 + s13**2 + t1**2 + t3**2 + exp_d1**2 + exp_d3**2,
96+
-s13 * exp_d3 - t1 * t3,
97+
-s23 * exp_d3 - t2 * t3,
98+
s12**2 + s13**2 + s23**2 + t1**2 + t2**2 + exp_d1**2 + exp_d2**2,
99+
]
100+
)
101+
102+
exp_2_alpha = np.exp(2 * alpha)
103+
theta *= exp_2_alpha
104+
105+
return theta
106+
107+
108+
def pseudo2cholesky(pseudo_inertia: np.ndarray) -> np.ndarray:
109+
n = pseudo_inertia.shape[0]
110+
indices = np.arange(n - 1, -1, -1)
111+
112+
reversed_inertia = pseudo_inertia[indices][:, indices]
113+
114+
L_prime = np.linalg.cholesky(reversed_inertia)
115+
116+
U = L_prime[indices][:, indices]
117+
118+
return U
119+
120+
121+
def cholesky2pseudo(U: np.ndarray) -> np.ndarray:
122+
return U @ U.T
123+
124+
125+
def pseudo2logchol(pseudo_inertia: np.ndarray) -> np.ndarray:
126+
U = pseudo2cholesky(pseudo_inertia)
127+
logchol = chol2logchol(U)
128+
return logchol
129+
130+
131+
def theta2logchol(theta: np.ndarray) -> np.ndarray:
132+
pseudo_inertia = theta2pseudo(theta)
133+
return pseudo2logchol(pseudo_inertia)

mujoco_sysid/regressors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def joint_body_regressor(mj_model, mj_data, body_id) -> npt.ArrayLike:
111111
def get_jacobian(mjmodel, mjdata, bodyid):
112112
R = mjdata.xmat[bodyid].reshape(3, 3)
113113

114-
jac_lin, jac_rot = np.zeros((3, 6)), np.zeros((3, 6))
114+
jac_lin, jac_rot = np.zeros((3, mjmodel.nv)), np.zeros((3, mjmodel.nv))
115115
mujoco.mj_jacBody(mjmodel, mjdata, jac_lin, jac_rot, bodyid)
116116

117117
return np.vstack([R.T @ jac_lin, R.T @ jac_rot])

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "mujoco_sysid"
33
description = "MuJoCo System Identification tools"
4-
version = "0.2.0"
4+
version = "0.2.1"
55
authors = [
66
{ name = "Lev Kozlov", email = "[email protected]" },
77
{ name = "Simeon Nedelchev", email = "[email protected]" },

tests/test_dynamics.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33

44
from mujoco_sysid import regressors
5+
from mujoco_sysid import parameters
56
from mujoco_sysid.utils import muj2pin
67

78
np.random.seed(0)
@@ -104,6 +105,8 @@ def test_joint_torque_regressor():
104105

105106
SAMPLES = 10000
106107

108+
theta = np.concatenate([parameters.get_dynamic_parameters(mjmodel, i) for i in mjmodel.jnt_bodyid])
109+
107110
for _ in range(SAMPLES):
108111
q, v, dv = np.random.rand(pinmodel.nq), np.random.rand(pinmodel.nv), np.random.rand(pinmodel.nv)
109112
pin.rnea(pinmodel, pindata, q, v, dv)
@@ -117,6 +120,9 @@ def test_joint_torque_regressor():
117120
pinY = pin.computeJointTorqueRegressor(pinmodel, pindata, q, v, dv)
118121
mjY = regressors.joint_torque_regressor(mjmodel, mjdata)
119122

123+
tau = pin.rnea(pinmodel, pindata, q, v, dv)
124+
125+
assert np.allclose(mjY @ theta, tau, atol=1e-6), f"Norm diff: {np.linalg.norm(mjY @ theta - tau)}"
120126
assert np.allclose(mjY, pinY, atol=1e-6), f"Norm diff: {np.linalg.norm(mjY - pinY)}"
121127

122128

0 commit comments

Comments
 (0)