Skip to content

Commit bb91077

Browse files
authored
Merge pull request #18 from adenzler-nvidia/dev/adenzler/fix-tile-fn
Run formatter and fix leftover occurrence of tile_fn in io.py
2 parents 95d89e2 + f6b5c4d commit bb91077

File tree

9 files changed

+216
-59
lines changed

9 files changed

+216
-59
lines changed

.github/workflows/ruff.yml

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2025 The Physics-Next Project Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
name: Ruff Format & Lint
17+
18+
on:
19+
push:
20+
branches: [main]
21+
pull_request:
22+
branches: [main]
23+
24+
jobs:
25+
check:
26+
runs-on: ubuntu-latest
27+
steps:
28+
- uses: actions/checkout@v4
29+
30+
- name: Set up Python
31+
uses: actions/setup-python@v5
32+
with:
33+
python-version: "3.x"
34+
35+
- name: Install Ruff
36+
run: pip install ruff
37+
38+
# Check formatting (fail if files aren't formatted)
39+
- name: Check Formatting
40+
run: ruff format --check .
41+
42+
# Check linting (fail on rule violations)
43+
#- name: Check Linting
44+
# run: ruff check --output-format=github .

mujoco/mjx/_src/forward_test.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# Copyright 2025 The Physics-Next Project Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
116
"""Tests for forward dynamics functions."""
217

318
from absl.testing import absltest
@@ -14,14 +29,13 @@
1429

1530
def _assert_eq(a, b, name):
1631
tol = _TOLERANCE * 10 # avoid test noise
17-
err_msg = f'mismatch: {name}'
32+
err_msg = f"mismatch: {name}"
1833
np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
1934

2035

2136
class ForwardTest(absltest.TestCase):
22-
2337
def _load(self, fname: str, is_sparse: bool = True):
24-
path = epath.resource_path('mujoco.mjx') / 'test_data' / fname
38+
path = epath.resource_path("mujoco.mjx") / "test_data" / fname
2539
mjm = mujoco.MjModel.from_xml_path(path.as_posix())
2640
mjm.opt.jacobian = is_sparse
2741
mjd = mujoco.MjData(mjm)
@@ -34,28 +48,30 @@ def _load(self, fname: str, is_sparse: bool = True):
3448

3549
def test_fwd_velocity(self):
3650
"""Tests MJX fwd_velocity."""
37-
_, mjd, m, d = self._load('humanoid/humanoid.xml')
38-
51+
_, mjd, m, d = self._load("humanoid/humanoid.xml")
52+
3953
d.actuator_velocity.zero_()
4054
mjx.fwd_velocity(m, d)
4155

42-
_assert_eq(d.actuator_velocity.numpy()[0], mjd.actuator_velocity, 'actuator_velocity')
43-
_assert_eq(d.qfrc_bias.numpy()[0], mjd.qfrc_bias, 'qfrc_bias')
56+
_assert_eq(
57+
d.actuator_velocity.numpy()[0], mjd.actuator_velocity, "actuator_velocity"
58+
)
59+
_assert_eq(d.qfrc_bias.numpy()[0], mjd.qfrc_bias, "qfrc_bias")
4460

4561
def test_fwd_acceleration(self):
4662
"""Tests MJX fwd_acceleration."""
47-
_, mjd, m, d = self._load('humanoid/humanoid.xml', is_sparse=False)
63+
_, mjd, m, d = self._load("humanoid/humanoid.xml", is_sparse=False)
4864

4965
for arr in (d.qfrc_smooth, d.qacc_smooth):
5066
arr.zero_()
5167

52-
mjx.factor_m(m, d) # for dense, get tile cholesky factorization
68+
mjx.factor_m(m, d) # for dense, get tile cholesky factorization
5369
mjx.fwd_acceleration(m, d)
5470

55-
_assert_eq(d.qfrc_smooth.numpy()[0], mjd.qfrc_smooth, 'qfrc_smooth')
56-
_assert_eq(d.qacc_smooth.numpy()[0], mjd.qacc_smooth, 'qacc_smooth')
71+
_assert_eq(d.qfrc_smooth.numpy()[0], mjd.qfrc_smooth, "qfrc_smooth")
72+
_assert_eq(d.qacc_smooth.numpy()[0], mjd.qacc_smooth, "qacc_smooth")
5773

5874

59-
if __name__ == '__main__':
75+
if __name__ == "__main__":
6076
wp.init()
6177
absltest.main()

mujoco/mjx/_src/io.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# Copyright 2025 The Physics-Next Project Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
116
import warp as wp
217
import mujoco
318
import numpy as np
@@ -19,7 +34,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
1934
m.nM = mjm.nM
2035
m.opt.gravity = wp.vec3(mjm.opt.gravity)
2136
m.opt.is_sparse = support.is_sparse(mjm)
22-
37+
2338
m.qpos0 = wp.array(mjm.qpos0, dtype=wp.float32, ndim=1)
2439
m.qpos_spring = wp.array(mjm.qpos_spring, dtype=wp.float32, ndim=1)
2540

@@ -75,10 +90,12 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
7590
qLD_tilesize = np.array(sorted(tiles.keys()))
7691

7792
m.qLD_update_tree = wp.array(qLD_update_tree, dtype=wp.vec3i, ndim=1)
78-
m.qLD_update_treeadr = wp.array(qLD_update_treeadr, dtype=wp.int32, ndim=1, device="cpu")
93+
m.qLD_update_treeadr = wp.array(
94+
qLD_update_treeadr, dtype=wp.int32, ndim=1, device="cpu"
95+
)
7996
m.qLD_tile = wp.array(qLD_tile, dtype=wp.int32, ndim=1)
80-
m.qLD_tileadr = wp.array(qLD_tileadr, dtype=wp.int32, ndim=1, device='cpu')
81-
m.qLD_tilesize = wp.array(qLD_tilesize, dtype=wp.int32, ndim=1, device='cpu')
97+
m.qLD_tileadr = wp.array(qLD_tileadr, dtype=wp.int32, ndim=1, device="cpu")
98+
m.qLD_tilesize = wp.array(qLD_tilesize, dtype=wp.int32, ndim=1, device="cpu")
8299
m.body_dofadr = wp.array(mjm.body_dofadr, dtype=wp.int32, ndim=1)
83100
m.body_dofnum = wp.array(mjm.body_dofnum, dtype=wp.int32, ndim=1)
84101
m.body_jntadr = wp.array(mjm.body_jntadr, dtype=wp.int32, ndim=1)
@@ -179,7 +196,13 @@ def tile(x):
179196

180197
# TODO(taylorhowell): sparse actuator_moment
181198
actuator_moment = np.zeros((mjm.nu, mjm.nv))
182-
mujoco.mju_sparse2dense(actuator_moment, mjd.actuator_moment, mjd.moment_rownnz, mjd.moment_rowadr, mjd.moment_colind)
199+
mujoco.mju_sparse2dense(
200+
actuator_moment,
201+
mjd.actuator_moment,
202+
mjd.moment_rownnz,
203+
mjd.moment_rowadr,
204+
mjd.moment_colind,
205+
)
183206

184207
d.qpos = wp.array(tile(mjd.qpos), dtype=wp.float32, ndim=2)
185208
d.qvel = wp.array(tile(mjd.qvel), dtype=wp.float32, ndim=2)

mujoco/mjx/_src/math.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,23 @@
1+
# Copyright 2025 The Physics-Next Project Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
116
import warp as wp
217

318
from . import types
419

20+
521
@wp.func
622
def mul_quat(u: wp.quat, v: wp.quat) -> wp.quat:
723
return wp.quat(
@@ -91,7 +107,7 @@ def motion_cross_force(v: wp.spatial_vector, f: wp.spatial_vector) -> wp.spatial
91107

92108
@wp.func
93109
def quat_to_vel(quat: wp.quat) -> wp.vec3:
94-
axis = wp.vec3(quat[1], quat[2], quat[3])
110+
axis = wp.vec3(quat[1], quat[2], quat[3])
95111
sin_a_2 = wp.norm_l2(axis)
96112

97113
if sin_a_2 == 0.0:

mujoco/mjx/_src/passive_test.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# Copyright 2025 The Physics-Next Project Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
116
"""Tests for passive force functions."""
217

318
from absl.testing import absltest
@@ -11,28 +26,28 @@
1126
# due to float precision
1227
_TOLERANCE = 5e-5
1328

29+
1430
def _assert_eq(a, b, name):
1531
tol = _TOLERANCE * 10 # avoid test noise
16-
err_msg = f'mismatch: {name}'
32+
err_msg = f"mismatch: {name}"
1733
np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
1834

1935

2036
class PassiveTest(absltest.TestCase):
21-
2237
def test_passive(self):
2338
"""Tests MJX passive."""
24-
_, mjd, m, d = test_util.fixture('pendula.xml')
39+
_, mjd, m, d = test_util.fixture("pendula.xml")
2540

2641
for arr in (d.qfrc_spring, d.qfrc_damper, d.qfrc_passive):
2742
arr.zero_()
2843

2944
mjx.passive(m, d)
3045

31-
_assert_eq(d.qfrc_spring.numpy()[0], mjd.qfrc_spring, 'qfrc_spring')
32-
_assert_eq(d.qfrc_damper.numpy()[0], mjd.qfrc_damper, 'qfrc_damper')
33-
_assert_eq(d.qfrc_passive.numpy()[0], mjd.qfrc_passive, 'qfrc_passive')
46+
_assert_eq(d.qfrc_spring.numpy()[0], mjd.qfrc_spring, "qfrc_spring")
47+
_assert_eq(d.qfrc_damper.numpy()[0], mjd.qfrc_damper, "qfrc_damper")
48+
_assert_eq(d.qfrc_passive.numpy()[0], mjd.qfrc_passive, "qfrc_passive")
3449

3550

36-
if __name__ == '__main__':
51+
if __name__ == "__main__":
3752
wp.init()
3853
absltest.main()

0 commit comments

Comments
 (0)