Skip to content

Commit bbfe7fe

Browse files
btabacopybara-github
authored andcommitted
Update mjx-warp render with partial codegen.
PiperOrigin-RevId: 869874195 Change-Id: I71a16a6c192873786a09211dc4d326c3a4115ad9
1 parent faf55f7 commit bbfe7fe

File tree

10 files changed

+195
-148
lines changed

10 files changed

+195
-148
lines changed

mjx/mujoco/mjx/_src/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1960,5 +1960,5 @@ def create_render_context(
19601960
Render context object that is JAX compatible.
19611961
"""
19621962
_check_warp_installed()
1963-
from mujoco.mjx.warp import render as mjxw_render # pylint: disable=g-import-not-at-top # pytype: disable=import-error
1964-
return mjxw_render.create_render_context(mjm, nworld=nworld, **kwargs)
1963+
from mujoco.mjx.warp import io as mjxw_io # pylint: disable=g-import-not-at-top # pytype: disable=import-error
1964+
return mjxw_io.create_render_context(mjm, nworld=nworld, **kwargs)

mjx/mujoco/mjx/_src/render_util.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,40 @@
1414
# ==============================================================================
1515
"""JAX render utilities for unpacking render output from MuJoCo Warp."""
1616

17+
import typing
18+
from typing import Any
19+
1720
import jax
1821
import jax.numpy as jnp
1922

2023

2124
def get_rgb(
25+
rc: Any,
2226
rgb_data: jax.Array,
2327
cam_id: int,
24-
width: int,
25-
height: int,
2628
) -> jax.Array:
2729
"""Unpack uint32 ABGR pixel data into float32 RGB.
2830
2931
Args:
30-
rgb_data: Packed render output, shape (nworld, ncam, H*W)
32+
rc: The RenderContext handle.
33+
rgb_data: Packed render output, shape (nworld, total_pixels)
3134
as uint32.
3235
cam_id: Camera index to extract.
33-
width: Image width.
34-
height: Image height.
3536
3637
Returns:
3738
Float32 RGB array with shape (nworld, H, W, 3), values
3839
in [0, 1].
3940
"""
40-
packed = rgb_data[:, cam_id]
41+
import mujoco.mjx.warp.render as mjxw_render # pylint: disable=g-import-not-at-top
42+
warp_rc = mjxw_render._MJX_RENDER_CONTEXT_BUFFERS[rc.key]
43+
rgb_adr = int(warp_rc.rgb_adr.numpy()[cam_id])
44+
width = int(warp_rc.cam_res.numpy()[cam_id][0])
45+
height = int(warp_rc.cam_res.numpy()[cam_id][1])
46+
47+
packed = jax.lax.dynamic_slice_in_dim(
48+
rgb_data, rgb_adr, width * height, axis=1
49+
)
50+
4151
r = (packed & 0xFF).astype(jnp.float32) / 255.0
4252
g = ((packed >> 8) & 0xFF).astype(jnp.float32) / 255.0
4353
b = ((packed >> 16) & 0xFF).astype(jnp.float32) / 255.0
@@ -47,27 +57,35 @@ def get_rgb(
4757

4858

4959
def get_depth(
60+
rc: Any,
5061
depth_data: jax.Array,
5162
cam_id: int,
52-
width: int,
53-
height: int,
5463
depth_scale: float,
5564
) -> jax.Array:
5665
"""Extract and normalize depth data for a camera.
5766
5867
Args:
59-
depth_data: Raw depth output, shape (nworld, ncam, H*W)
68+
rc: The RenderContext handle.
69+
depth_data: Raw depth output, shape (nworld, total_pixels)
6070
as float32.
6171
cam_id: Camera index to extract.
62-
width: Image width.
63-
height: Image height.
6472
depth_scale: Scale factor for normalizing depth values.
6573
6674
Returns:
6775
Float32 depth array with shape (nworld, H, W), clamped
6876
to [0, 1].
6977
"""
70-
raw = depth_data[:, cam_id]
78+
import mujoco.mjx.warp.render as mjxw_render # pylint: disable=g-import-not-at-top
79+
warp_rc = mjxw_render._MJX_RENDER_CONTEXT_BUFFERS[rc.key]
80+
depth_adr = int(warp_rc.depth_adr.numpy()[cam_id])
81+
width = int(warp_rc.cam_res.numpy()[cam_id][0])
82+
height = int(warp_rc.cam_res.numpy()[cam_id][1])
83+
84+
raw = jax.lax.dynamic_slice_in_dim(
85+
depth_data, depth_adr, width * height, axis=1
86+
)
87+
7188
nworld = depth_data.shape[0]
7289
depth = jnp.clip(raw / depth_scale, 0.0, 1.0)
7390
return depth.reshape(nworld, height, width)
91+

mjx/mujoco/mjx/warp/bvh.py

Lines changed: 58 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,106 +12,128 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
1516
"""DO NOT EDIT. This file is auto-generated."""
1617

1718
import dataclasses
18-
import functools
1919
import jax
20-
import mujoco
2120
from mujoco.mjx._src import types
2221
from mujoco.mjx.warp import ffi
23-
# Re-use the render registry
24-
from mujoco.mjx.warp.render import _MJX_RENDER_CONTEXT_BUFFERS
22+
from mujoco.mjx.warp.io import _MJX_RENDER_CONTEXT_BUFFERS
2523
from mujoco.mjx.warp.types import RenderContext
2624
import mujoco.mjx.third_party.mujoco_warp as mjwarp
25+
from mujoco.mjx.third_party.mujoco_warp._src import types as mjwp_types
2726
import warp as wp
2827

28+
2929
_m = mjwarp.Model(
3030
**{f.name: None for f in dataclasses.fields(mjwarp.Model) if f.init}
3131
)
3232
_d = mjwarp.Data(
3333
**{f.name: None for f in dataclasses.fields(mjwarp.Data) if f.init}
3434
)
35+
_o = mjwarp.Option(
36+
**{f.name: None for f in dataclasses.fields(mjwarp.Option) if f.init}
37+
)
38+
_s = mjwarp.Statistic(
39+
**{f.name: None for f in dataclasses.fields(mjwarp.Statistic) if f.init}
40+
)
41+
_c = mjwarp.Contact(
42+
**{f.name: None for f in dataclasses.fields(mjwarp.Contact) if f.init}
43+
)
44+
_e = mjwarp.Constraint(
45+
**{f.name: None for f in dataclasses.fields(mjwarp.Constraint) if f.init}
46+
)
3547

3648

3749
@ffi.format_args_for_warp
3850
def _refit_bvh_shim(
3951
# Model
52+
nworld: int,
53+
flex_dim: wp.array(dtype=int),
54+
flex_elem: wp.array(dtype=int),
55+
flex_elemnum: wp.array(dtype=int),
56+
flex_vertadr: wp.array(dtype=int),
4057
geom_dataid: wp.array(dtype=int),
4158
geom_size: wp.array2d(dtype=wp.vec3),
4259
geom_type: wp.array(dtype=int),
4360
nflex: int,
4461
nflexelemdata: int,
4562
nflexvert: int,
46-
flex_dim: wp.array(dtype=int),
47-
flex_elem: wp.array(dtype=int),
48-
flex_elemnum: wp.array(dtype=int),
49-
flex_vertadr: wp.array(dtype=int),
5063
# Data
64+
flexvert_xpos: wp.array2d(dtype=wp.vec3),
5165
geom_xmat: wp.array2d(dtype=wp.mat33),
5266
geom_xpos: wp.array2d(dtype=wp.vec3),
53-
flexvert_xpos: wp.array2d(dtype=wp.vec3),
5467
# Registry
5568
rc_id: int,
56-
geom_xpos_out: wp.array2d(dtype=wp.vec3),
69+
# Dummy output
70+
dummy: wp.array(dtype=int),
5771
):
72+
_m.stat = _s
73+
_m.opt = _o
74+
_d.efc = _e
75+
_d.contact = _c
76+
_m.flex_dim = flex_dim
77+
_m.flex_elem = flex_elem
78+
_m.flex_elemnum = flex_elemnum
79+
_m.flex_vertadr = flex_vertadr
5880
_m.geom_dataid = geom_dataid
5981
_m.geom_size = geom_size
6082
_m.geom_type = geom_type
6183
_m.nflex = nflex
6284
_m.nflexelemdata = nflexelemdata
6385
_m.nflexvert = nflexvert
64-
_m.flex_dim = flex_dim
65-
_m.flex_elem = flex_elem
66-
_m.flex_elemnum = flex_elemnum
67-
_m.flex_vertadr = flex_vertadr
86+
_d.flexvert_xpos = flexvert_xpos
6887
_d.geom_xmat = geom_xmat
6988
_d.geom_xpos = geom_xpos
70-
_d.flexvert_xpos = flexvert_xpos
71-
_d.nworld = geom_xpos.shape[0]
72-
89+
_d.nworld = nworld
7390
render_context = _MJX_RENDER_CONTEXT_BUFFERS[rc_id]
91+
92+
dummy.zero_()
7493
mjwarp.refit_bvh(_m, _d, render_context)
75-
wp.copy(geom_xpos_out, geom_xpos)
7694

7795

7896
def _refit_bvh_jax_impl(m: types.Model, d: types.Data, ctx: RenderContext):
79-
nworld = d.qpos.shape[0]
80-
ngeom = d.geom_xpos.shape[1]
81-
97+
output_dims = {'dummy': (d.qpos.shape[0],)}
8298
jf = ffi.jax_callable_variadic_tuple(
8399
_refit_bvh_shim,
84100
num_outputs=1,
85-
output_dims={'geom_xpos_out': (nworld, ngeom, 3)},
101+
output_dims=output_dims,
86102
vmap_method=None,
103+
in_out_argnames=set([]),
104+
stage_in_argnames=set(['geom_size', 'geom_xmat', 'geom_xpos']),
105+
stage_out_argnames=set([]),
106+
graph_mode=m.opt._impl.graph_mode,
87107
)
88108
out = jf(
109+
d.qpos.shape[0],
110+
m._impl.flex_dim,
111+
m._impl.flex_elem,
112+
m._impl.flex_elemnum,
113+
m._impl.flex_vertadr,
89114
m.geom_dataid,
90115
m.geom_size,
91116
m.geom_type,
92-
m.nflex,
93-
m.nflexelemdata,
94-
m.nflexvert,
95-
m.flex_dim,
96-
m.flex_elem,
97-
m.flex_elemnum,
98-
m.flex_vertadr,
117+
m._impl.nflex,
118+
m._impl.nflexelemdata,
119+
m._impl.nflexvert,
120+
d._impl.flexvert_xpos,
99121
d.geom_xmat,
100122
d.geom_xpos,
101-
d.flexvert_xpos,
102123
ctx.key,
103124
)
104-
return d.replace(geom_xpos=out[0])
125+
d = d.tree_replace({'time': d.time + out[0]})
126+
return d
105127

106128

107129
@jax.custom_batching.custom_vmap
108-
@functools.partial(ffi.marshal_jax_warp_callable)
130+
@ffi.marshal_jax_warp_callable
109131
def refit_bvh(m: types.Model, d: types.Data, ctx: RenderContext):
110132
return _refit_bvh_jax_impl(m, d, ctx)
111133

112134

113135
@refit_bvh.def_vmap
114-
@functools.partial(ffi.marshal_custom_vmap)
115-
def refit_bvh_vmap(unused_axis_size, is_batched, m, d, ctx):
116-
out = refit_bvh(m, d, ctx)
117-
return out, is_batched[1]
136+
@ffi.marshal_custom_vmap
137+
def refit_bvh_vmap(unused_axis_size, is_batched, m, d, ctx: RenderContext):
138+
d = refit_bvh(m, d, ctx)
139+
return d, is_batched[1]

mjx/mujoco/mjx/warp/collision_driver.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def _collision_shim(
223223
_d.naconmax = naconmax
224224
_d.ncollision = ncollision
225225
_d.nworld = nworld
226+
226227
mjwarp.collision(_m, _d)
227228

228229

@@ -249,7 +250,7 @@ def _collision_jax_impl(m: types.Model, d: types.Data):
249250
num_outputs=15,
250251
output_dims=output_dims,
251252
vmap_method=None,
252-
in_out_argnames={
253+
in_out_argnames=set([
253254
'nacon',
254255
'ncollision',
255256
'contact__dim',
@@ -265,8 +266,8 @@ def _collision_jax_impl(m: types.Model, d: types.Data):
265266
'contact__solreffriction',
266267
'contact__type',
267268
'contact__worldid',
268-
},
269-
stage_in_argnames={
269+
]),
270+
stage_in_argnames=set([
270271
'geom_aabb',
271272
'geom_friction',
272273
'geom_gap',
@@ -285,8 +286,8 @@ def _collision_jax_impl(m: types.Model, d: types.Data):
285286
'pair_solimp',
286287
'pair_solref',
287288
'pair_solreffriction',
288-
},
289-
stage_out_argnames={},
289+
]),
290+
stage_out_argnames=set([]),
290291
graph_mode=m.opt._impl.graph_mode,
291292
)
292293
out = jf(

mjx/mujoco/mjx/warp/forward.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,7 @@ def _forward_shim(
909909
_d.xpos = xpos
910910
_d.xquat = xquat
911911
_d.nworld = nworld
912+
912913
mjwarp.forward(_m, _d)
913914

914915

@@ -1012,7 +1013,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data):
10121013
num_outputs=92,
10131014
output_dims=output_dims,
10141015
vmap_method=None,
1015-
in_out_argnames={
1016+
in_out_argnames=set([
10161017
'act_dot',
10171018
'actuator_force',
10181019
'actuator_length',
@@ -1105,8 +1106,8 @@ def _forward_jax_impl(m: types.Model, d: types.Data):
11051106
'efc__state',
11061107
'efc__type',
11071108
'efc__vel',
1108-
},
1109-
stage_in_argnames={
1109+
]),
1110+
stage_in_argnames=set([
11101111
'act',
11111112
'act_dot',
11121113
'actuator_acc0',
@@ -1242,8 +1243,8 @@ def _forward_jax_impl(m: types.Model, d: types.Data):
12421243
'xmat',
12431244
'xpos',
12441245
'xquat',
1245-
},
1246-
stage_out_argnames={
1246+
]),
1247+
stage_out_argnames=set([
12471248
'act_dot',
12481249
'actuator_force',
12491250
'actuator_length',
@@ -1276,7 +1277,7 @@ def _forward_jax_impl(m: types.Model, d: types.Data):
12761277
'xmat',
12771278
'xpos',
12781279
'xquat',
1279-
},
1280+
]),
12801281
graph_mode=m.opt._impl.graph_mode,
12811282
)
12821283
out = jf(
@@ -2688,6 +2689,7 @@ def _step_shim(
26882689
_d.xpos = xpos
26892690
_d.xquat = xquat
26902691
_d.nworld = nworld
2692+
26912693
mjwarp.step(_m, _d)
26922694

26932695

@@ -2795,7 +2797,7 @@ def _step_jax_impl(m: types.Model, d: types.Data):
27952797
num_outputs=96,
27962798
output_dims=output_dims,
27972799
vmap_method=None,
2798-
in_out_argnames={
2800+
in_out_argnames=set([
27992801
'act',
28002802
'act_dot',
28012803
'actuator_force',
@@ -2892,8 +2894,8 @@ def _step_jax_impl(m: types.Model, d: types.Data):
28922894
'efc__state',
28932895
'efc__type',
28942896
'efc__vel',
2895-
},
2896-
stage_in_argnames={
2897+
]),
2898+
stage_in_argnames=set([
28972899
'act',
28982900
'act_dot',
28992901
'actuator_acc0',
@@ -3029,8 +3031,8 @@ def _step_jax_impl(m: types.Model, d: types.Data):
30293031
'xmat',
30303032
'xpos',
30313033
'xquat',
3032-
},
3033-
stage_out_argnames={
3034+
]),
3035+
stage_out_argnames=set([
30343036
'act',
30353037
'act_dot',
30363038
'actuator_force',
@@ -3067,7 +3069,7 @@ def _step_jax_impl(m: types.Model, d: types.Data):
30673069
'xmat',
30683070
'xpos',
30693071
'xquat',
3070-
},
3072+
]),
30713073
graph_mode=m.opt._impl.graph_mode,
30723074
)
30733075
out = jf(

0 commit comments

Comments
 (0)