Skip to content

Commit 430b83d

Browse files
authored
Feature/mjx (#5)
* feat: add parameters implementation * fix: sort * dump * feat: matched xbody object velocity * feat: start testing mjx regressor * style: ruff check * fix: alter the test to use jit and vmap * fix: clear the notebook
1 parent 78c2460 commit 430b83d

File tree

9 files changed

+677
-85
lines changed

9 files changed

+677
-85
lines changed

.github/workflows/build.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
- name: Install dependencies
3636
run: |
3737
python -m pip install --upgrade pip
38-
pip install .
38+
pip install .[mjx_cpu]
3939
pip install pytest robot_descriptions pin
4040
- name: Run tests
4141
run: pytest

examples/match_energy.ipynb

Lines changed: 14 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 1,
5+
"execution_count": null,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -21,7 +21,7 @@
2121
},
2222
{
2323
"cell_type": "code",
24-
"execution_count": 2,
24+
"execution_count": null,
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
@@ -40,7 +40,7 @@
4040
},
4141
{
4242
"cell_type": "code",
43-
"execution_count": 3,
43+
"execution_count": null,
4444
"metadata": {},
4545
"outputs": [],
4646
"source": [
@@ -49,7 +49,7 @@
4949
},
5050
{
5151
"cell_type": "code",
52-
"execution_count": 4,
52+
"execution_count": null,
5353
"metadata": {},
5454
"outputs": [],
5555
"source": [
@@ -61,20 +61,9 @@
6161
},
6262
{
6363
"cell_type": "code",
64-
"execution_count": 5,
64+
"execution_count": null,
6565
"metadata": {},
66-
"outputs": [
67-
{
68-
"data": {
69-
"text/plain": [
70-
"(array([-1.21974604, 0. ]), -1.2197460404740708)"
71-
]
72-
},
73-
"execution_count": 5,
74-
"metadata": {},
75-
"output_type": "execute_result"
76-
}
77-
],
66+
"outputs": [],
7867
"source": [
7968
"mjdata.qpos[:] = q.copy()\n",
8069
"mjdata.qvel[:] = v.copy()\n",
@@ -89,20 +78,9 @@
8978
},
9079
{
9180
"cell_type": "code",
92-
"execution_count": 6,
81+
"execution_count": null,
9382
"metadata": {},
94-
"outputs": [
95-
{
96-
"data": {
97-
"text/plain": [
98-
"(-1.2197457679147197, [-1.2197457679147197, 0.0])"
99-
]
100-
},
101-
"execution_count": 6,
102-
"metadata": {},
103-
"output_type": "execute_result"
104-
}
105-
],
83+
"outputs": [],
10684
"source": [
10785
"(\n",
10886
" np.sum(\n",
@@ -120,23 +98,9 @@
12098
},
12199
{
122100
"cell_type": "code",
123-
"execution_count": 7,
101+
"execution_count": null,
124102
"metadata": {},
125-
"outputs": [
126-
{
127-
"data": {
128-
"text/plain": [
129-
"((60,),\n",
130-
" array([ 6.73326000e-01, 1.66311522e-06, -1.69664685e-04, 1.56021081e-02,\n",
131-
" 1.64484975e-03, -5.95801269e-08, 1.08083746e-03, -4.38536417e-07,\n",
132-
" 4.43147004e-06, 8.39403034e-04]))"
133-
]
134-
},
135-
"execution_count": 7,
136-
"metadata": {},
137-
"output_type": "execute_result"
138-
}
139-
],
103+
"outputs": [],
140104
"source": [
141105
"theta = np.concatenate([parameters.get_dynamic_parameters(mjmodel, i) for i in mjmodel.jnt_bodyid])\n",
142106
"\n",
@@ -145,32 +109,9 @@
145109
},
146110
{
147111
"cell_type": "code",
148-
"execution_count": 8,
112+
"execution_count": null,
149113
"metadata": {},
150-
"outputs": [
151-
{
152-
"name": "stdout",
153-
"output_type": "stream",
154-
"text": [
155-
"for body 0 norm of difference is 4.903343079507622e-07\n",
156-
"for body 1 norm of difference is 2.5970031568529883e-06\n",
157-
"for body 2 norm of difference is 2.650078190181762e-07\n",
158-
"for body 3 norm of difference is 3.701542518534774e-07\n",
159-
"for body 4 norm of difference is 8.053173890678282e-08\n",
160-
"for body 5 norm of difference is 7.002258930267969e-08\n"
161-
]
162-
},
163-
{
164-
"data": {
165-
"text/plain": [
166-
"((60,), 2.6839308799870235e-06)"
167-
]
168-
},
169-
"execution_count": 8,
170-
"metadata": {},
171-
"output_type": "execute_result"
172-
}
173-
],
114+
"outputs": [],
174115
"source": [
175116
"params = []\n",
176117
"\n",
@@ -195,20 +136,9 @@
195136
},
196137
{
197138
"cell_type": "code",
198-
"execution_count": 9,
139+
"execution_count": null,
199140
"metadata": {},
200-
"outputs": [
201-
{
202-
"data": {
203-
"text/plain": [
204-
"-1.2197460404740712"
205-
]
206-
},
207-
"execution_count": 9,
208-
"metadata": {},
209-
"output_type": "execute_result"
210-
}
211-
],
141+
"outputs": [],
212142
"source": [
213143
"reg_en = regressors.mj_energyRegressor(mjmodel, mjdata)[2]\n",
214144
"\n",

examples/match_object_vel.ipynb

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"name": "stderr",
10+
"output_type": "stream",
11+
"text": [
12+
"2024-06-29 16:09:01.637386: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
13+
]
14+
}
15+
],
16+
"source": [
17+
"import jax\n",
18+
"import jax.numpy as jnp\n",
19+
"import mujoco\n",
20+
"import numpy as np\n",
21+
"from mujoco import mjx\n",
22+
"from robot_descriptions.z1_mj_description import MJCF_PATH\n",
23+
"\n",
24+
"key = jax.random.PRNGKey(0)\n",
25+
"\n",
26+
"mjmodel = mujoco.MjModel.from_xml_path(MJCF_PATH)\n",
27+
"mjdata = mujoco.MjData(mjmodel)\n",
28+
"\n",
29+
"# alter the model so it becomes mjx compatible\n",
30+
"mjmodel.dof_frictionloss = 0\n",
31+
"mjmodel.opt.integrator = 0\n",
32+
"\n",
33+
"mjxmodel = mjx.put_model(mjmodel)\n",
34+
"mjxdata = mjx.put_data(mjmodel, mjdata)"
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": 2,
40+
"metadata": {},
41+
"outputs": [
42+
{
43+
"data": {
44+
"text/plain": [
45+
"'/home/leo/.cache/robot_descriptions/mujoco_menagerie/unitree_z1/z1.xml'"
46+
]
47+
},
48+
"execution_count": 2,
49+
"metadata": {},
50+
"output_type": "execute_result"
51+
}
52+
],
53+
"source": [
54+
"MJCF_PATH"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": 3,
60+
"metadata": {},
61+
"outputs": [
62+
{
63+
"data": {
64+
"text/plain": [
65+
"(Array([ 1.1901639 , -1.0996888 , 0.44367844, 0.5984697 , -0.39189556,\n",
66+
" 0.69261974], dtype=float32),\n",
67+
" Array([ 0.46018356, -2.068578 , -0.21438177, -0.9898306 , -0.6789304 ,\n",
68+
" 0.27362573], dtype=float32))"
69+
]
70+
},
71+
"execution_count": 3,
72+
"metadata": {},
73+
"output_type": "execute_result"
74+
}
75+
],
76+
"source": [
77+
"q, v = jax.random.normal(key, (2, mjmodel.nq))\n",
78+
"\n",
79+
"q, v"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": 4,
85+
"metadata": {},
86+
"outputs": [
87+
{
88+
"name": "stdout",
89+
"output_type": "stream",
90+
"text": [
91+
"bodyid: 2, v: [0. 0. 0.], w: [0. 0. 0.46018356]\n",
92+
"bodyid: 3, v: [-3.09150672e-18 0.00000000e+00 -1.57468993e-18], w: [ 0.41005399 -2.068578 0.20886511]\n",
93+
"bodyid: 4, v: [ 0.31078859 -0.07310279 -0.6539035 ], w: [ 0.28069364 -2.28295977 0.36466421]\n",
94+
"bodyid: 5, v: [ 0.23727103 -0.00960553 -0.02728739], w: [ 0.02646465 -3.27279039 0.45942195]\n",
95+
"bodyid: 6, v: [0.21066843 0.11146764 0.20180794], w: [ 1.27447096 -3.0145615 -0.21950845]\n",
96+
"bodyid: 7, v: [0.21066843 0.30104535 0.20516525], w: [ 1.5480967 -2.46010192 1.75603632]\n"
97+
]
98+
}
99+
],
100+
"source": [
101+
"mjdata.qpos[:] = q\n",
102+
"mjdata.qvel[:] = v\n",
103+
"\n",
104+
"mujoco.mj_step(mjmodel, mjdata)\n",
105+
"\n",
106+
"velocity = np.zeros(6)\n",
107+
"for bodyid in mjmodel.jnt_bodyid:\n",
108+
" mujoco.mj_objectVelocity(mjmodel, mjdata, 2, bodyid, velocity, 1)\n",
109+
"\n",
110+
" print(f\"bodyid: {bodyid}, v: {velocity[3:]}, w: {velocity[:3]}\")"
111+
]
112+
},
113+
{
114+
"cell_type": "code",
115+
"execution_count": 5,
116+
"metadata": {},
117+
"outputs": [
118+
{
119+
"name": "stdout",
120+
"output_type": "stream",
121+
"text": [
122+
"[[ 0. 0. 0. 0. 0. 0. ]\n",
123+
" [ 0. 0. 0. 0. 0. 0. ]\n",
124+
" [ 0. 0. 0.4602 0.0068 -0.0015 0. ]\n",
125+
" [ 1.9205 -0.7685 0.4602 0.0957 0.2207 -0.0307]\n",
126+
" [ 2.1196 -0.8481 0.4602 0.08 0.1816 0.0002]\n",
127+
" [ 3.0386 -1.2159 0.4602 0.0734 0.1651 0.0061]\n",
128+
" [ 3.0531 -1.1796 -0.2176 0.0331 0.1794 0.006 ]\n",
129+
" [ 3.2439 -0.984 -0.2031 0.0367 0.1764 -0.0013]]\n"
130+
]
131+
}
132+
],
133+
"source": [
134+
"with np.printoptions(precision=4, suppress=True):\n",
135+
" print(mjdata.cvel)"
136+
]
137+
},
138+
{
139+
"cell_type": "code",
140+
"execution_count": 6,
141+
"metadata": {},
142+
"outputs": [],
143+
"source": [
144+
"from mujoco_sysid.mjx.regressors import object_velocity\n",
145+
"\n",
146+
"mjxdata = mjxdata.replace(qpos=q, qvel=v)\n",
147+
"mjxdata = mjx.step(mjxmodel, mjxdata)"
148+
]
149+
},
150+
{
151+
"cell_type": "code",
152+
"execution_count": 7,
153+
"metadata": {},
154+
"outputs": [
155+
{
156+
"name": "stdout",
157+
"output_type": "stream",
158+
"text": [
159+
"[[ 0. 0. 0. 0. 0. 0. ]\n",
160+
" [ 0. 0. 0. 0. 0. 0. ]\n",
161+
" [ 0. 0. 0.4602 0.0068 -0.0015 0. ]\n",
162+
" [ 1.9205 -0.7685 0.4602 0.0957 0.2207 -0.0307]\n",
163+
" [ 2.1196 -0.8481 0.4602 0.08 0.1816 0.0002]\n",
164+
" [ 3.0386 -1.2159 0.4602 0.0734 0.1651 0.0061]\n",
165+
" [ 3.0531 -1.1796 -0.2176 0.0331 0.1794 0.006 ]\n",
166+
" [ 3.2439 -0.984 -0.2031 0.0367 0.1764 -0.0013]]\n"
167+
]
168+
}
169+
],
170+
"source": [
171+
"with np.printoptions(precision=4, suppress=True):\n",
172+
" print(mjxdata.cvel)"
173+
]
174+
},
175+
{
176+
"cell_type": "code",
177+
"execution_count": 8,
178+
"metadata": {},
179+
"outputs": [
180+
{
181+
"name": "stdout",
182+
"output_type": "stream",
183+
"text": [
184+
"bodyid: 2, v: [0. 0. 0.], w: [0. 0. 0.46018356]\n",
185+
"bodyid: 3, v: [0. 0. 0.], w: [ 0.41005418 -2.0685785 0.20886509]\n",
186+
"bodyid: 4, v: [ 0.31078863 -0.07310276 -0.6539037 ], w: [ 0.28069377 -2.2829604 0.36466417]\n",
187+
"bodyid: 5, v: [ 0.2372711 -0.00960554 -0.02728742], w: [ 0.02646467 -3.2727916 0.45942217]\n",
188+
"bodyid: 6, v: [0.21066852 0.1114677 0.20180808], w: [ 1.2744716 -3.0145626 -0.21950865]\n",
189+
"bodyid: 7, v: [0.21066849 0.30104554 0.20516539], w: [ 1.5480974 -2.4601028 1.7560369]\n"
190+
]
191+
}
192+
],
193+
"source": [
194+
"for bodyid in mjmodel.jnt_bodyid:\n",
195+
" velocity = object_velocity(mjxmodel, mjxdata, bodyid)\n",
196+
"\n",
197+
" print(f\"bodyid: {bodyid}, v: {velocity[3:]}, w: {velocity[:3]}\")"
198+
]
199+
}
200+
],
201+
"metadata": {
202+
"kernelspec": {
203+
"display_name": "venv",
204+
"language": "python",
205+
"name": "python3"
206+
},
207+
"language_info": {
208+
"codemirror_mode": {
209+
"name": "ipython",
210+
"version": 3
211+
},
212+
"file_extension": ".py",
213+
"mimetype": "text/x-python",
214+
"name": "python",
215+
"nbconvert_exporter": "python",
216+
"pygments_lexer": "ipython3",
217+
"version": "3.10.12"
218+
}
219+
},
220+
"nbformat": 4,
221+
"nbformat_minor": 2
222+
}

mujoco_sysid/mjx/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)