Skip to content

Commit 60cf843

Browse files
Merge pull request #156 from Andrew-Luo1:rscope
PiperOrigin-RevId: 777744280 Change-Id: I5acbb1c47ed108568244281fe856c1d0e3c0d473
2 parents 81dfe51 + 6170d77 commit 60cf843

File tree

26 files changed

+115
-65
lines changed

26 files changed

+115
-65
lines changed

mujoco_playground/_src/dm_control_suite/acrobot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ def __init__(
5757
self._margin = 0.0 if sparse else 1.0
5858

5959
self._xml_path = _XML_PATH.as_posix()
60+
self._model_assets = common.get_assets()
6061
self._mj_model = mujoco.MjModel.from_xml_string(
61-
_XML_PATH.read_text(), common.get_assets()
62+
_XML_PATH.read_text(), self._model_assets
6263
)
6364
self._mj_model.opt.timestep = self.sim_dt
6465
self._mjx_model = mjx.put_model(self._mj_model)

mujoco_playground/_src/dm_control_suite/ball_in_cup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ def __init__(
5353
)
5454

5555
self._xml_path = _XML_PATH.as_posix()
56+
self._model_assets = common.get_assets()
5657
self._mj_model = mujoco.MjModel.from_xml_string(
57-
_XML_PATH.read_text(), common.get_assets()
58+
_XML_PATH.read_text(), self._model_assets
5859
)
5960
self._mj_model.opt.timestep = self.sim_dt
6061
self._mjx_model = mjx.put_model(self._mj_model)

mujoco_playground/_src/dm_control_suite/cheetah.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ def __init__(
5656
)
5757

5858
self._xml_path = _XML_PATH.as_posix()
59+
self._model_assets = common.get_assets()
5960
self._mj_model = mujoco.MjModel.from_xml_string(
60-
_XML_PATH.read_text(), common.get_assets()
61+
_XML_PATH.read_text(), self._model_assets
6162
)
6263
self._mj_model.opt.timestep = self.sim_dt
6364
self._mjx_model = mjx.put_model(self._mj_model)

mujoco_playground/_src/dm_control_suite/finger.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def default_config() -> config_dict.ConfigDict:
5252

5353

5454
def _make_turn_model(
55-
xml_path: epath.Path, target_radius: float
55+
xml_path: epath.Path, target_radius: float, assets: Dict[str, Any]
5656
) -> mujoco.MjModel:
57-
spec = mujoco.MjSpec.from_string(xml_path.read_text(), common.get_assets())
57+
spec = mujoco.MjSpec.from_string(xml_path.read_text(), assets)
5858
target_site = None
5959
for site in spec.sites:
6060
if site.name == "target":
@@ -65,10 +65,10 @@ def _make_turn_model(
6565
return spec.compile()
6666

6767

68-
def _make_spin_model(xml_path: epath.Path) -> mujoco.MjModel:
69-
model = mujoco.MjModel.from_xml_string(
70-
xml_path.read_text(), common.get_assets()
71-
)
68+
def _make_spin_model(
69+
xml_path: epath.Path, assets: Dict[str, Any]
70+
) -> mujoco.MjModel:
71+
model = mujoco.MjModel.from_xml_string(xml_path.read_text(), assets)
7272
model.site_rgba[model.site("target").id, 3] = 0
7373
model.site_rgba[model.site("tip").id, 3] = 0
7474
model.dof_damping[model.joint("hinge").id] = 0.03
@@ -90,7 +90,8 @@ def __init__(
9090
)
9191

9292
self._xml_path = _XML_PATH.as_posix()
93-
self._mj_model = _make_spin_model(_XML_PATH)
93+
self._model_assets = common.get_assets()
94+
self._mj_model = _make_spin_model(_XML_PATH, self._model_assets)
9495
self._mj_model.opt.timestep = self.sim_dt
9596
self._mjx_model = mjx.put_model(self._mj_model)
9697
self._post_init()
@@ -214,7 +215,10 @@ def __init__(
214215
)
215216

216217
self._xml_path = _XML_PATH.as_posix()
217-
self._mj_model = _make_turn_model(_XML_PATH, target_radius)
218+
self._model_assets = common.get_assets()
219+
self._mj_model = _make_turn_model(
220+
_XML_PATH, target_radius, self._model_assets
221+
)
218222
self._mj_model.opt.timestep = self.sim_dt
219223
self._mjx_model = mjx.put_model(self._mj_model)
220224
self._post_init()

mujoco_playground/_src/dm_control_suite/fish.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ def __init__(
6464
)
6565

6666
self._xml_path = _XML_PATH.as_posix()
67+
self._model_assets = common.get_assets()
6768
self._mj_model = mujoco.MjModel.from_xml_string(
68-
_XML_PATH.read_text(), common.get_assets()
69+
_XML_PATH.read_text(), self._model_assets
6970
)
7071
self._mj_model.opt.timestep = self.sim_dt
7172
self._mjx_model = mjx.put_model(self._mj_model)

mujoco_playground/_src/dm_control_suite/hopper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,9 @@ def __init__(
7373
]
7474

7575
self._xml_path = _XML_PATH.as_posix()
76+
self._model_assets = common.get_assets()
7677
self._mj_model = mujoco.MjModel.from_xml_string(
77-
_XML_PATH.read_text(), common.get_assets()
78+
_XML_PATH.read_text(), self._model_assets
7879
)
7980
self._mj_model.opt.timestep = self.sim_dt
8081
self._mjx_model = mjx.put_model(self._mj_model)

mujoco_playground/_src/dm_control_suite/humanoid.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ def __init__(
6767
self._stand_or_move_reward = self._move_reward
6868

6969
self._xml_path = _XML_PATH.as_posix()
70+
self._model_assets = common.get_assets()
7071
self._mj_model = mujoco.MjModel.from_xml_string(
71-
_XML_PATH.read_text(), common.get_assets()
72+
_XML_PATH.read_text(), self._model_assets
7273
)
7374
self._mj_model.opt.timestep = self.sim_dt
7475
self._mjx_model = mjx.put_model(self._mj_model)

mujoco_playground/_src/dm_control_suite/pendulum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ def __init__(
5757
)
5858

5959
self._xml_path = _XML_PATH.as_posix()
60+
self._model_assets = common.get_assets()
6061
self._mj_model = mujoco.MjModel.from_xml_string(
61-
_XML_PATH.read_text(), common.get_assets()
62+
_XML_PATH.read_text(), self._model_assets
6263
)
6364
self._mj_model.opt.timestep = self.sim_dt
6465
self._mjx_model = mjx.put_model(self._mj_model)

mujoco_playground/_src/dm_control_suite/point_mass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ def __init__(
5454
)
5555

5656
self._xml_path = _XML_PATH.as_posix()
57+
self._model_assets = common.get_assets()
5758
self._mj_model = mujoco.MjModel.from_xml_string(
58-
_XML_PATH.read_text(), common.get_assets()
59+
_XML_PATH.read_text(), self._model_assets
5960
)
6061
self._mj_model.opt.timestep = self.sim_dt
6162
self._mjx_model = mjx.put_model(self._mj_model)

mujoco_playground/_src/dm_control_suite/reacher.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@ def default_config() -> config_dict.ConfigDict:
4242
)
4343

4444

45-
def _make_model(xml_path: epath.Path, target_size: float) -> mujoco.MjModel:
46-
spec = mujoco.MjSpec.from_string(xml_path.read_text(), common.get_assets())
45+
def _make_model(
46+
xml_path: epath.Path, target_size: float, assets: Dict[str, Any]
47+
) -> mujoco.MjModel:
48+
spec = mujoco.MjSpec.from_string(xml_path.read_text(), assets)
4749
if mujoco.__version__ >= "3.3.0":
4850
target_body = spec.body("target")
4951
else:
@@ -70,7 +72,8 @@ def __init__(
7072

7173
self._target_size = target_size
7274
self._xml_path = _XML_PATH.as_posix()
73-
self._mj_model = _make_model(_XML_PATH, target_size)
75+
self._model_assets = common.get_assets()
76+
self._mj_model = _make_model(_XML_PATH, target_size, self._model_assets)
7477
self._mj_model.opt.timestep = self.sim_dt
7578
self._mjx_model = mjx.put_model(self._mj_model)
7679
self._post_init()

0 commit comments

Comments
 (0)