Skip to content

Commit ee06d3f

Browse files
btabacopybara-github
authored andcommitted
add back #9
PiperOrigin-RevId: 718684254 Change-Id: Ie9a5a9653224b9be47024c600c768ae21b3556e9
1 parent 532ecd7 commit ee06d3f

23 files changed

+162
-127
lines changed

learning/train_jax_ppo.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,12 @@
1414
# ==============================================================================
1515
"""Train a PPO agent using JAX on the specified environment."""
1616

17-
import os
18-
19-
xla_flags = os.environ.get("XLA_FLAGS", "")
20-
xla_flags += " --xla_gpu_triton_gemm_any=True"
21-
os.environ["XLA_FLAGS"] = xla_flags
22-
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
23-
os.environ["MUJOCO_GL"] = "egl"
24-
2517
from datetime import datetime
2618
import functools
2719
import json
20+
import os
2821
import time
22+
import warnings
2923

3024
from absl import app
3125
from absl import flags
@@ -39,7 +33,6 @@
3933
import jax.numpy as jp
4034
import mediapy as media
4135
from ml_collections import config_dict
42-
from ml_collections import config_flags
4336
import mujoco
4437
from orbax import checkpoint as ocp
4538
from tensorboardX import SummaryWriter
@@ -52,11 +45,16 @@
5245
from mujoco_playground.config import locomotion_params
5346
from mujoco_playground.config import manipulation_params
5447

48+
xla_flags = os.environ.get("XLA_FLAGS", "")
49+
xla_flags += " --xla_gpu_triton_gemm_any=True"
50+
os.environ["XLA_FLAGS"] = xla_flags
51+
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
52+
os.environ["MUJOCO_GL"] = "egl"
53+
5554
# Ignore the info logs from brax
5655
logging.set_verbosity(logging.WARNING)
5756

5857
# Suppress warnings
59-
import warnings
6058

6159
# Suppress RuntimeWarnings from JAX
6260
warnings.filterwarnings("ignore", category=RuntimeWarning, module="jax")
@@ -267,11 +265,11 @@ def main(argv):
267265
print(f"Checkpoint path: {ckpt_path}")
268266

269267
# Save environment configuration
270-
with open(ckpt_path / "config.json", "w") as fp:
268+
with open(ckpt_path / "config.json", "w", encoding="utf-8") as fp:
271269
json.dump(env_cfg.to_json(), fp, indent=4)
272270

273271
# Define policy parameters function for saving checkpoints
274-
def policy_params_fn(current_step, make_policy, params):
272+
def policy_params_fn(current_step, make_policy, params): # pylint: disable=unused-argument
275273
orbax_checkpointer = ocp.PyTreeCheckpointer()
276274
save_args = orbax_utils.save_args_from_target(params)
277275
path = ckpt_path / f"{current_step}"
@@ -352,7 +350,7 @@ def progress(num_steps, metrics):
352350
)
353351

354352
# Train or load the model
355-
make_inference_fn, params, _ = train_fn(
353+
make_inference_fn, params, _ = train_fn( # pylint: disable=no-value-for-parameter
356354
environment=env,
357355
progress_fn=progress,
358356
eval_env=None if _VISION.value else eval_env,
@@ -389,7 +387,7 @@ def progress(num_steps, metrics):
389387
rollout = [state0]
390388

391389
# Run evaluation rollout
392-
for i in range(env_cfg.episode_length):
390+
for _ in range(env_cfg.episode_length):
393391
act_rng, rng = jax.random.split(rng)
394392
ctrl, _ = jit_inference_fn(state.obs, act_rng)
395393
state = jit_step(state, ctrl)

learning/train_rsl_rl.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
# pylint: disable=wrong-import-position
1516
"""Train a PPO agent using RSL-RL for the specified environment."""
1617

1718
import os
@@ -28,7 +29,6 @@
2829
from absl import flags
2930
from absl import logging
3031
import jax
31-
import jax.numpy as jp
3232
import mediapy as media
3333
from ml_collections import config_dict
3434
import mujoco
@@ -136,7 +136,9 @@ def main(argv):
136136
wandb.config.update({"env_name": _ENV_NAME.value})
137137

138138
# Save environment config to JSON
139-
with open(os.path.join(ckpt_path, "config.json"), "w") as fp:
139+
with open(
140+
os.path.join(ckpt_path, "config.json"), "w", encoding="utf-8"
141+
) as fp:
140142
json.dump(env_cfg.to_json(), fp, indent=4)
141143

142144
# Domain randomization
@@ -146,7 +148,7 @@ def main(argv):
146148
render_trajectory = []
147149

148150
# Callback to gather states for rendering
149-
def render_callback(env, state):
151+
def render_callback(_, state):
150152
render_trajectory.append(state)
151153

152154
# Create the environment
@@ -231,8 +233,11 @@ def render_callback(env, state):
231233
fps = 1.0 / base_env.dt / render_every
232234
traj = rollout[::render_every]
233235
frames = eval_env.render(
234-
traj, camera=_CAMERA.value, height=480, width=640,
235-
scene_option=scene_option
236+
traj,
237+
camera=_CAMERA.value,
238+
height=480,
239+
width=640,
240+
scene_option=scene_option,
236241
)
237242
media.write_video("rollout.mp4", frames, fps=fps)
238243
print("Rollout video saved as 'rollout.mp4'.")

mujoco_playground/_src/locomotion/g1/randomize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
"""Utilities for randomization."""
1516
import jax
1617
from mujoco import mjx
1718

mujoco_playground/_src/locomotion/locomotion_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525

2626
class TestSuite(parameterized.TestCase):
27+
"""Tests for the locomotion environments."""
2728

2829
@parameterized.named_parameters(
2930
{"testcase_name": f"test_can_create_{env_name}", "env_name": env_name}

mujoco_playground/_src/manipulation/franka_emika_panda/open_cabinet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import jax
2020
import jax.numpy as jp
2121
from ml_collections import config_dict
22-
import mujoco
2322
from mujoco import mjx
23+
import mujoco # pylint: disable=unused-import
2424
from mujoco.mjx._src import math
2525

2626
from mujoco_playground._src import collision

mujoco_playground/_src/manipulation/franka_emika_panda/pick_cartesian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class PandaPickCubeCartesian(pick.PandaPickCube):
8888
"""Environment for training the Franka Panda robot to pick up a cube in
8989
Cartesian space."""
9090

91-
def __init__(
91+
def __init__( # pylint: disable=non-parent-init-called,super-init-not-called
9292
self,
9393
config=default_config(),
9494
config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,

mujoco_playground/_src/manipulation/franka_emika_panda/randomize_vision.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def domain_randomize(
6060
) -> Tuple[mjx.Model, mjx.Model]:
6161
"""Tile the necessary axes for the Madrona BatchRenderer."""
6262
mj_model = pick_cartesian.PandaPickCubeCartesian().mj_model
63-
FLOOR_GEOM_ID = mj_model.geom('floor').id
64-
BOX_GEOM_ID = mj_model.geom('box').id
65-
STRIP_GEOM_ID = mj_model.geom('init_space').id
63+
floor_geom_id = mj_model.geom('floor').id
64+
box_geom_id = mj_model.geom('box').id
65+
strip_geom_id = mj_model.geom('init_space').id
6666

6767
in_axes = jax.tree_util.tree_map(lambda x: None, mjx_model)
6868
in_axes = in_axes.tree_replace({
@@ -93,16 +93,16 @@ def rand(rng: jax.Array, light_position: jax.Array):
9393
rgba = jp.array(
9494
[jax.random.uniform(key_box, (), minval=0.5, maxval=1.0), 0.0, 0.0, 1.0]
9595
)
96-
geom_rgba = mjx_model.geom_rgba.at[BOX_GEOM_ID].set(rgba)
96+
geom_rgba = mjx_model.geom_rgba.at[box_geom_id].set(rgba)
9797

9898
strip_white = jax.random.uniform(key_strip, (), minval=0.8, maxval=1.0)
99-
geom_rgba = geom_rgba.at[STRIP_GEOM_ID].set(
99+
geom_rgba = geom_rgba.at[strip_geom_id].set(
100100
jp.array([strip_white, strip_white, strip_white, 1.0])
101101
)
102102

103103
# Sample a shade of gray
104104
gray_scale = jax.random.uniform(key_floor, (), minval=0.0, maxval=0.25)
105-
geom_rgba = geom_rgba.at[FLOOR_GEOM_ID].set(
105+
geom_rgba = geom_rgba.at[floor_geom_id].set(
106106
jp.array([gray_scale, gray_scale, gray_scale, 1.0])
107107
)
108108

@@ -112,11 +112,11 @@ def rand(rng: jax.Array, light_position: jax.Array):
112112
jax.random.randint(key_matid, shape=(num_geoms,), minval=0, maxval=10)
113113
+ mat_offset
114114
)
115-
geom_matid = geom_matid.at[BOX_GEOM_ID].set(
115+
geom_matid = geom_matid.at[box_geom_id].set(
116116
-2
117117
) # Use the above randomized colors
118-
geom_matid = geom_matid.at[FLOOR_GEOM_ID].set(-2)
119-
geom_matid = geom_matid.at[STRIP_GEOM_ID].set(-2)
118+
geom_matid = geom_matid.at[floor_geom_id].set(-2)
119+
geom_matid = geom_matid.at[strip_geom_id].set(-2)
120120

121121
#### Cameras ####
122122
key_pos, key_ori, key = jax.random.split(key, 3)
@@ -134,7 +134,7 @@ def rand(rng: jax.Array, light_position: jax.Array):
134134
assert (
135135
nlight == 1
136136
), f'Sim2Real was trained with a single light source, got {nlight}'
137-
key_lsha, key_ldir, key_ldct, key = jax.random.split(key, 4)
137+
key_lsha, key_ldir, key = jax.random.split(key, 3)
138138

139139
# Direction
140140
shine_at = jp.array([0.661, -0.001, 0.179]) # Gripper starting position

mujoco_playground/_src/manipulation/leap_hand/rotate_z.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import numpy as np
2424

2525
from mujoco_playground._src import mjx_env
26-
from mujoco_playground._src import reward
2726
from mujoco_playground._src.manipulation.leap_hand import base as leap_hand_base
2827
from mujoco_playground._src.manipulation.leap_hand import leap_hand_constants as consts
2928

@@ -145,7 +144,7 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
145144
rewards = {
146145
k: v * self._config.reward_config.scales[k] for k, v in rewards.items()
147146
}
148-
reward = sum(rewards.values()) * self.dt
147+
reward = sum(rewards.values()) * self.dt # pylint: disable=redefined-outer-name
149148

150149
state.info["last_last_act"] = state.info["last_act"]
151150
state.info["last_act"] = action

mujoco_playground/_src/manipulation/manipulation_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525

2626
class TestSuite(parameterized.TestCase):
27+
"""Tests for the manipulation environments."""
2728

2829
@parameterized.named_parameters(
2930
{"testcase_name": f"test_can_create_{env_name}", "env_name": env_name}

mujoco_playground/_src/mjx_env.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import numpy as np
2929
import tqdm
3030

31-
3231
# Root path is used for loading XML strings directly using etils.epath.
3332
ROOT_PATH = epath.Path(__file__).parent
3433
# Base directory for external dependencies.

0 commit comments

Comments
 (0)