Skip to content

Commit ade9a4b

Browse files
Euler/Advance (#19)
- Adds euler integrator and advance function - Adds internal API for solve/factor m with explicit arguments
1 parent bb91077 commit ade9a4b

File tree

12 files changed

+438
-48
lines changed

12 files changed

+438
-48
lines changed

mujoco/mjx/__init__.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
1+
# Copyright 2025 The Physics-Next Project Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
116
"""Public API for MJX."""
217

18+
from ._src.forward import euler
319
from ._src.forward import forward
420
from ._src.forward import fwd_acceleration
521
from ._src.forward import fwd_position
@@ -12,9 +28,9 @@
1228
from ._src.smooth import com_vel
1329
from ._src.smooth import crb
1430
from ._src.smooth import factor_m
15-
from ._src.smooth import solve_m
1631
from ._src.smooth import kinematics
1732
from ._src.smooth import rne
33+
from ._src.smooth import solve_m
1834
from ._src.support import is_sparse
1935
from ._src.test_util import benchmark
2036
from ._src.types import *

mujoco/mjx/_src/forward.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,210 @@
1+
# Copyright 2025 The Physics-Next Project Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
from typing import Optional
17+
118
import warp as wp
19+
20+
from . import math
221
from . import passive
322
from . import smooth
423

24+
from .types import array2df
525
from .types import Model
626
from .types import Data
27+
from .types import MJ_MINVAL
28+
from .types import MJ_DSBL_EULERDAMP
29+
30+
31+
def _advance(
32+
m: Model,
33+
d: Data,
34+
act_dot: wp.array,
35+
qacc: wp.array,
36+
qvel: Optional[wp.array] = None,
37+
) -> Data:
38+
"""Advance state and time given activation derivatives and acceleration."""
39+
40+
# TODO(team): can we assume static timesteps?
41+
42+
@wp.kernel
43+
def next_activation(
44+
m: Model,
45+
d: Data,
46+
act_dot_in: array2df,
47+
):
48+
worldId, actid = wp.tid()
49+
50+
# get the high/low range for each actuator state
51+
limited = m.actuator_actlimited[actid]
52+
range_low = wp.select(limited, -wp.inf, m.actuator_actrange[actid, 0])
53+
range_high = wp.select(limited, wp.inf, m.actuator_actrange[actid, 1])
54+
55+
# get the actual actuation - skip if -1 (means stateless actuator)
56+
act_adr = m.actuator_actadr[actid]
57+
if act_adr == -1:
58+
return
59+
60+
acts = d.act[worldId]
61+
acts_dot = act_dot_in[worldId]
62+
63+
act = acts[act_adr]
64+
act_dot = acts_dot[act_adr]
65+
66+
# check dynType
67+
dyn_type = m.actuator_dyntype[actid]
68+
dyn_prm = m.actuator_dynprm[actid, 0]
69+
70+
# advance the actuation
71+
if dyn_type == 3: # wp.static(WarpDynType.FILTEREXACT):
72+
tau = wp.select(dyn_prm < MJ_MINVAL, dyn_prm, MJ_MINVAL)
73+
act = act + act_dot * tau * (1.0 - wp.exp(-m.opt.timestep / tau))
74+
else:
75+
act = act + act_dot * m.opt.timestep
76+
77+
# apply limits
78+
wp.clamp(act, range_low, range_high)
79+
80+
acts[act_adr] = act
81+
82+
@wp.kernel
83+
def advance_velocities(m: Model, d: Data, qacc: array2df):
84+
worldId, tid = wp.tid()
85+
d.qvel[worldId, tid] = d.qvel[worldId, tid] + qacc[worldId, tid] * m.opt.timestep
86+
87+
@wp.kernel
88+
def integrate_joint_positions(m: Model, d: Data, qvel_in: array2df):
89+
worldId, jntid = wp.tid()
90+
91+
jnt_type = m.jnt_type[jntid]
92+
qpos_adr = m.jnt_qposadr[jntid]
93+
dof_adr = m.jnt_dofadr[jntid]
94+
qpos = d.qpos[worldId]
95+
qvel = qvel_in[worldId]
96+
97+
if jnt_type == 0: # free joint
98+
qpos_pos = wp.vec3(qpos[qpos_adr], qpos[qpos_adr + 1], qpos[qpos_adr + 2])
99+
qvel_lin = wp.vec3(qvel[dof_adr], qvel[dof_adr + 1], qvel[dof_adr + 2])
100+
101+
qpos_new = qpos_pos + m.opt.timestep * qvel_lin
102+
103+
qpos_quat = wp.quat(
104+
qpos[qpos_adr + 3],
105+
qpos[qpos_adr + 4],
106+
qpos[qpos_adr + 5],
107+
qpos[qpos_adr + 6],
108+
)
109+
qvel_ang = wp.vec3(qvel[dof_adr + 3], qvel[dof_adr + 4], qvel[dof_adr + 5])
110+
111+
qpos_quat_new = math.quat_integrate(qpos_quat, qvel_ang, m.opt.timestep)
112+
113+
qpos[qpos_adr] = qpos_new[0]
114+
qpos[qpos_adr + 1] = qpos_new[1]
115+
qpos[qpos_adr + 2] = qpos_new[2]
116+
qpos[qpos_adr + 3] = qpos_quat_new[0]
117+
qpos[qpos_adr + 4] = qpos_quat_new[1]
118+
qpos[qpos_adr + 5] = qpos_quat_new[2]
119+
qpos[qpos_adr + 6] = qpos_quat_new[3]
120+
121+
elif jnt_type == 1: # ball joint
122+
qpos_quat = wp.quat(
123+
qpos[qpos_adr],
124+
qpos[qpos_adr + 1],
125+
qpos[qpos_adr + 2],
126+
qpos[qpos_adr + 3],
127+
)
128+
qvel_ang = wp.vec3(qvel[dof_adr], qvel[dof_adr + 1], qvel[dof_adr + 2])
129+
130+
qpos_quat_new = math.quat_integrate(qpos_quat, qvel_ang, m.opt.timestep)
131+
132+
qpos[qpos_adr] = qpos_quat_new[0]
133+
qpos[qpos_adr + 1] = qpos_quat_new[1]
134+
qpos[qpos_adr + 2] = qpos_quat_new[2]
135+
qpos[qpos_adr + 3] = qpos_quat_new[3]
136+
137+
else: # if jnt_type in (JointType.HINGE, JointType.SLIDE):
138+
qpos[qpos_adr] = qpos[qpos_adr] + m.opt.timestep * qvel[dof_adr]
139+
140+
# skip if no stateful actuators.
141+
if m.na:
142+
wp.launch(next_activation, dim=(d.nworld, m.nu), inputs=[m, d, act_dot])
143+
144+
wp.launch(advance_velocities, dim=(d.nworld, m.nv), inputs=[m, d, qacc])
145+
146+
# advance positions with qvel if given, d.qvel otherwise (semi-implicit)
147+
if qvel is not None:
148+
qvel_in = qvel
149+
else:
150+
qvel_in = d.qvel
151+
152+
wp.launch(integrate_joint_positions, dim=(d.nworld, m.njnt), inputs=[m, d, qvel_in])
153+
154+
d.time = d.time + m.opt.timestep
155+
return d
156+
157+
158+
def euler(m: Model, d: Data) -> Data:
159+
"""Euler integrator, semi-implicit in velocity."""
160+
# integrate damping implicitly
161+
162+
def add_damping_sum_qfrc(m: Model, d: Data, is_sparse: bool):
163+
@wp.kernel
164+
def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data):
165+
worldId, tid = wp.tid()
166+
167+
dof_Madr = m.dof_Madr[tid]
168+
d.qM_integration[worldId, 0, dof_Madr] += m.opt.timestep * m.dof_damping[dof_Madr]
169+
170+
d.qfrc_integration[worldId, tid] = (
171+
d.qfrc_smooth[worldId, tid] + d.qfrc_constraint[worldId, tid]
172+
)
173+
174+
@wp.kernel
175+
def add_damping_sum_qfrc_kernel_dense(m: Model, d: Data):
176+
worldid, i, j = wp.tid()
177+
178+
damping = wp.select(i == j, 0.0, m.opt.timestep * m.dof_damping[i])
179+
d.qM_integration[worldid, i, j] = d.qM[worldid, i, j] + damping
180+
181+
if i == 0:
182+
d.qfrc_integration[worldid, j] = (
183+
d.qfrc_smooth[worldid, j] + d.qfrc_constraint[worldid, j]
184+
)
185+
186+
if is_sparse:
187+
wp.copy(d.qM_integration, d.qM)
188+
wp.launch(add_damping_sum_qfrc_kernel_sparse, dim=(d.nworld, m.nv), inputs=[m, d])
189+
else:
190+
wp.launch(
191+
add_damping_sum_qfrc_kernel_dense, dim=(d.nworld, m.nv, m.nv), inputs=[m, d]
192+
)
193+
194+
if not m.opt.disableflags & MJ_DSBL_EULERDAMP:
195+
add_damping_sum_qfrc(m, d, m.opt.is_sparse)
196+
smooth.factor_i(m, d, d.qM_integration, d.qLD_integration, d.qLDiagInv_integration)
197+
smooth.solve_LD(
198+
m,
199+
d,
200+
d.qLD_integration,
201+
d.qLDiagInv_integration,
202+
d.qacc_integration,
203+
d.qfrc_integration,
204+
)
205+
return _advance(m, d, d.act_dot, d.qacc_integration)
206+
207+
return _advance(m, d, d.act_dot, d.qacc)
7208

8209

9210
def fwd_position(m: Model, d: Data):

mujoco/mjx/_src/forward_test.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
from absl.testing import absltest
1919
from etils import epath
20-
import mujoco
21-
from mujoco import mjx
2220
import numpy as np
2321
import warp as wp
2422

23+
import mujoco
24+
from mujoco import mjx
25+
2526
# tolerance for difference between MuJoCo and MJX smooth calculations - mostly
2627
# due to float precision
2728
_TOLERANCE = 5e-5
@@ -71,6 +72,58 @@ def test_fwd_acceleration(self):
7172
_assert_eq(d.qfrc_smooth.numpy()[0], mjd.qfrc_smooth, "qfrc_smooth")
7273
_assert_eq(d.qacc_smooth.numpy()[0], mjd.qacc_smooth, "qacc_smooth")
7374

75+
def test_eulerdamp(self):
76+
path = epath.resource_path("mujoco.mjx") / "test_data/pendula.xml"
77+
mjm = mujoco.MjModel.from_xml_path(path.as_posix())
78+
self.assertTrue((mjm.dof_damping > 0).any())
79+
80+
mjd = mujoco.MjData(mjm)
81+
mjd.qvel[:] = 1.0
82+
mjd.qacc[:] = 1.0
83+
mujoco.mj_forward(mjm, mjd)
84+
85+
m = mjx.put_model(mjm)
86+
d = mjx.put_data(mjm, mjd)
87+
88+
mjx.euler(m, d)
89+
mujoco.mj_Euler(mjm, mjd)
90+
91+
_assert_eq(d.qpos.numpy()[0], mjd.qpos, "qpos")
92+
_assert_eq(d.act.numpy()[0], mjd.act, "act")
93+
94+
# also test sparse
95+
mjm.opt.jacobian = mujoco.mjtJacobian.mjJAC_SPARSE
96+
mjd = mujoco.MjData(mjm)
97+
mjd.qvel[:] = 1.0
98+
mjd.qacc[:] = 1.0
99+
mujoco.mj_forward(mjm, mjd)
100+
101+
m = mjx.put_model(mjm)
102+
d = mjx.put_data(mjm, mjd)
103+
104+
mjx.euler(m, d)
105+
mujoco.mj_Euler(mjm, mjd)
106+
107+
_assert_eq(d.qpos.numpy()[0], mjd.qpos, "qpos")
108+
_assert_eq(d.act.numpy()[0], mjd.act, "act")
109+
110+
def test_disable_eulerdamp(self):
111+
path = epath.resource_path("mujoco.mjx") / "test_data/pendula.xml"
112+
mjm = mujoco.MjModel.from_xml_path(path.as_posix())
113+
mjm.opt.disableflags = mjm.opt.disableflags | mujoco.mjtDisableBit.mjDSBL_EULERDAMP
114+
115+
mjd = mujoco.MjData(mjm)
116+
mujoco.mj_forward(mjm, mjd)
117+
mjd.qvel[:] = 1.0
118+
mjd.qacc[:] = 1.0
119+
120+
m = mjx.put_model(mjm)
121+
d = mjx.put_data(mjm, mjd)
122+
123+
mjx.euler(m, d)
124+
125+
np.testing.assert_allclose(d.qvel.numpy()[0], 1 + mjm.opt.timestep)
126+
74127

75128
if __name__ == "__main__":
76129
wp.init()

0 commit comments

Comments
 (0)