@@ -60,9 +60,9 @@ def domain_randomize(
6060) -> Tuple [mjx .Model , mjx .Model ]:
6161 """Tile the necessary axes for the Madrona BatchRenderer."""
6262 mj_model = pick_cartesian .PandaPickCubeCartesian ().mj_model
63- FLOOR_GEOM_ID = mj_model .geom ('floor' ).id
64- BOX_GEOM_ID = mj_model .geom ('box' ).id
65- STRIP_GEOM_ID = mj_model .geom ('init_space' ).id
63+ floor_geom_id = mj_model .geom ('floor' ).id
64+ box_geom_id = mj_model .geom ('box' ).id
65+ strip_geom_id = mj_model .geom ('init_space' ).id
6666
6767 in_axes = jax .tree_util .tree_map (lambda x : None , mjx_model )
6868 in_axes = in_axes .tree_replace ({
@@ -93,16 +93,16 @@ def rand(rng: jax.Array, light_position: jax.Array):
9393 rgba = jp .array (
9494 [jax .random .uniform (key_box , (), minval = 0.5 , maxval = 1.0 ), 0.0 , 0.0 , 1.0 ]
9595 )
96- geom_rgba = mjx_model .geom_rgba .at [BOX_GEOM_ID ].set (rgba )
96+ geom_rgba = mjx_model .geom_rgba .at [box_geom_id ].set (rgba )
9797
9898 strip_white = jax .random .uniform (key_strip , (), minval = 0.8 , maxval = 1.0 )
99- geom_rgba = geom_rgba .at [STRIP_GEOM_ID ].set (
99+ geom_rgba = geom_rgba .at [strip_geom_id ].set (
100100 jp .array ([strip_white , strip_white , strip_white , 1.0 ])
101101 )
102102
103103 # Sample a shade of gray
104104 gray_scale = jax .random .uniform (key_floor , (), minval = 0.0 , maxval = 0.25 )
105- geom_rgba = geom_rgba .at [FLOOR_GEOM_ID ].set (
105+ geom_rgba = geom_rgba .at [floor_geom_id ].set (
106106 jp .array ([gray_scale , gray_scale , gray_scale , 1.0 ])
107107 )
108108
@@ -112,11 +112,11 @@ def rand(rng: jax.Array, light_position: jax.Array):
112112 jax .random .randint (key_matid , shape = (num_geoms ,), minval = 0 , maxval = 10 )
113113 + mat_offset
114114 )
115- geom_matid = geom_matid .at [BOX_GEOM_ID ].set (
115+ geom_matid = geom_matid .at [box_geom_id ].set (
116116 - 2
117117 ) # Use the above randomized colors
118- geom_matid = geom_matid .at [FLOOR_GEOM_ID ].set (- 2 )
119- geom_matid = geom_matid .at [STRIP_GEOM_ID ].set (- 2 )
118+ geom_matid = geom_matid .at [floor_geom_id ].set (- 2 )
119+ geom_matid = geom_matid .at [strip_geom_id ].set (- 2 )
120120
121121 #### Cameras ####
122122 key_pos , key_ori , key = jax .random .split (key , 3 )
@@ -134,7 +134,7 @@ def rand(rng: jax.Array, light_position: jax.Array):
134134 assert (
135135 nlight == 1
136136 ), f'Sim2Real was trained with a single light source, got { nlight } '
137- key_lsha , key_ldir , key_ldct , key = jax .random .split (key , 4 )
137+ key_lsha , key_ldir , key = jax .random .split (key , 3 )
138138
139139 # Direction
140140 shine_at = jp .array ([0.661 , - 0.001 , 0.179 ]) # Gripper starting position
0 commit comments