Skip to content

Commit f612959

Browse files
btabacopybara-github
authored andcommitted
Save depth images in visualize_render.py
PiperOrigin-RevId: 871454618 Change-Id: If844e2c4af10e271536c4e6a36e760e8d4d50374
1 parent 448e221 commit f612959

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

mjx/mujoco/mjx/warp/visualize_render.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,23 +175,40 @@ def init(worldid):
175175
)(mx, dx_batch, rc)
176176

177177
rgb_packed = out_batch[0]
178-
print(f' rgb shape: {rgb_packed.shape}\n')
178+
depth_packed = out_batch[1]
179+
print(f' rgb shape: {rgb_packed.shape}')
180+
print(f' depth shape: {depth_packed.shape}\n')
179181

180182
rgb = jax.vmap(
181183
render_util.get_rgb, in_axes=(None, 0, None)
182184
)(rc, rgb_packed, _CAMERA_ID.value)
183185

186+
depth = jax.vmap(
187+
render_util.get_depth, in_axes=(None, 0, None, None)
188+
)(rc, depth_packed, _CAMERA_ID.value, 10.0)
189+
184190
single_path = os.path.join(
185191
_OUTPUT_DIR.value, f'camera_{_CAMERA_ID.value}.png'
186192
)
187193
_save_single(rgb, single_path)
188194

195+
depth_rgb = np.repeat(np.asarray(depth)[..., None], 3, axis=-1)
196+
depth_single_path = os.path.join(
197+
_OUTPUT_DIR.value, f'depth_{_CAMERA_ID.value}.png'
198+
)
199+
_save_single(depth_rgb, depth_single_path)
200+
189201
if _NWORLD.value > 1:
190202
tiled_path = os.path.join(
191203
_OUTPUT_DIR.value, f'tiled_{_CAMERA_ID.value}.png'
192204
)
193205
_save_tiled(rgb, tiled_path)
194206

207+
depth_tiled_path = os.path.join(
208+
_OUTPUT_DIR.value, f'depth_tiled_{_CAMERA_ID.value}.png'
209+
)
210+
_save_tiled(depth_rgb, depth_tiled_path)
211+
195212
print('\ndone.')
196213

197214

0 commit comments

Comments
 (0)