@@ -175,23 +175,40 @@ def init(worldid):
175175 )(mx , dx_batch , rc )
176176
177177 rgb_packed = out_batch [0 ]
178- print (f' rgb shape: { rgb_packed .shape } \n ' )
178+ depth_packed = out_batch [1 ]
179+ print (f' rgb shape: { rgb_packed .shape } ' )
180+ print (f' depth shape: { depth_packed .shape } \n ' )
179181
180182 rgb = jax .vmap (
181183 render_util .get_rgb , in_axes = (None , 0 , None )
182184 )(rc , rgb_packed , _CAMERA_ID .value )
183185
186+ depth = jax .vmap (
187+ render_util .get_depth , in_axes = (None , 0 , None , None )
188+ )(rc , depth_packed , _CAMERA_ID .value , 10.0 )
189+
184190 single_path = os .path .join (
185191 _OUTPUT_DIR .value , f'camera_{ _CAMERA_ID .value } .png'
186192 )
187193 _save_single (rgb , single_path )
188194
195+ depth_rgb = np .repeat (np .asarray (depth )[..., None ], 3 , axis = - 1 )
196+ depth_single_path = os .path .join (
197+ _OUTPUT_DIR .value , f'depth_{ _CAMERA_ID .value } .png'
198+ )
199+ _save_single (depth_rgb , depth_single_path )
200+
189201 if _NWORLD .value > 1 :
190202 tiled_path = os .path .join (
191203 _OUTPUT_DIR .value , f'tiled_{ _CAMERA_ID .value } .png'
192204 )
193205 _save_tiled (rgb , tiled_path )
194206
207+ depth_tiled_path = os .path .join (
208+ _OUTPUT_DIR .value , f'depth_tiled_{ _CAMERA_ID .value } .png'
209+ )
210+ _save_tiled (depth_rgb , depth_tiled_path )
211+
195212 print ('\n done.' )
196213
197214
0 commit comments