@@ -155,8 +155,8 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
155155 jnt_limited_slide_hinge_adr , dtype = wp .int32 , ndim = 1
156156 )
157157 m .jnt_type = wp .array (mjm .jnt_type , dtype = wp .int32 , ndim = 1 )
158- m .jnt_solref = wp .array (mjm .jnt_solref , dtype = wp .float32 , ndim = 2 )
159- m .jnt_solimp = wp .array (mjm .jnt_solimp , dtype = wp . float32 , ndim = 2 )
158+ m .jnt_solref = wp .array (mjm .jnt_solref , dtype = wp .vec2f , ndim = 1 )
159+ m .jnt_solimp = wp .array (mjm .jnt_solimp , dtype = types . vec5 , ndim = 1 )
160160 m .jnt_qposadr = wp .array (mjm .jnt_qposadr , dtype = wp .int32 , ndim = 1 )
161161 m .jnt_dofadr = wp .array (mjm .jnt_dofadr , dtype = wp .int32 , ndim = 1 )
162162 m .jnt_axis = wp .array (mjm .jnt_axis , dtype = wp .vec3 , ndim = 1 )
@@ -198,12 +198,19 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
198198 return m
199199
200200
201- def make_data (mjm : mujoco .MjModel , nworld : int = 1 , njmax : int = - 1 ) -> types .Data :
201+ def make_data (
202+ mjm : mujoco .MjModel , nworld : int = 1 , nconmax : int = - 1 , njmax : int = - 1
203+ ) -> types .Data :
202204 d = types .Data ()
203205 d .nworld = nworld
206+ d .ncon_total = wp .zeros ((1 ,), dtype = wp .int32 , ndim = 1 )
204207 d .nefc_total = wp .zeros ((1 ,), dtype = wp .int32 , ndim = 1 )
205208
206209 # TODO(team): move to Model?
210+ if nconmax == - 1 :
211+ # TODO(team): heuristic for nconmax
212+ nconmax = 512
213+ d .nconmax = nconmax
207214 if njmax == - 1 :
208215 # TODO(team): heuristic for njmax
209216 njmax = 512
@@ -255,17 +262,18 @@ def make_data(mjm: mujoco.MjModel, nworld: int = 1, njmax: int = -1) -> types.Da
255262 d .cdof_dot = wp .zeros ((nworld , mjm .nv ), dtype = wp .spatial_vector )
256263 d .qfrc_bias = wp .zeros ((nworld , mjm .nv ), dtype = wp .float32 )
257264 d .contact = types .Contact ()
258- d .contact .dist = wp .zeros ((nworld , d .ncon ), dtype = wp .float32 )
259- d .contact .pos = wp .zeros ((nworld , d .ncon ), dtype = wp .vec3f )
260- d .contact .frame = wp .zeros ((nworld , d .ncon ), dtype = wp .mat33f )
261- d .contact .includemargin = wp .zeros ((nworld , d .ncon ), dtype = wp .float32 )
262- d .contact .friction = wp .zeros ((nworld , d .ncon , 5 ), dtype = wp .float32 )
263- d .contact .solref = wp .zeros ((nworld , d .ncon , types .MJ_NREF ), dtype = wp .float32 )
264- d .contact .solreffriction = wp .zeros ((nworld , d .ncon , types .MJ_NREF ), dtype = wp .float32 )
265- d .contact .solimp = wp .zeros ((nworld , d .ncon , types .MJ_NIMP ), dtype = wp .float32 )
266- d .contact .dim = wp .zeros ((nworld , d .ncon ), dtype = wp .int32 )
267- d .contact .geom = wp .zeros ((nworld , d .ncon , 2 ), dtype = wp .int32 )
268- d .contact .efc_address = wp .zeros ((nworld , d .ncon ), dtype = wp .int32 )
265+ d .contact .dist = wp .zeros ((nconmax ,), dtype = wp .float32 )
266+ d .contact .pos = wp .zeros ((nconmax ,), dtype = wp .vec3f )
267+ d .contact .frame = wp .zeros ((nconmax ,), dtype = wp .mat33f )
268+ d .contact .includemargin = wp .zeros ((nconmax ,), dtype = wp .float32 )
269+ d .contact .friction = wp .zeros ((nconmax ,), dtype = types .vec5 )
270+ d .contact .solref = wp .zeros ((nconmax ,), dtype = wp .vec2f )
271+ d .contact .solreffriction = wp .zeros ((nconmax ,), dtype = wp .vec2f )
272+ d .contact .solimp = wp .zeros ((nconmax ,), dtype = types .vec5 )
273+ d .contact .dim = wp .zeros ((nconmax ,), dtype = wp .int32 )
274+ d .contact .geom = wp .zeros ((nconmax ,), dtype = wp .vec2i )
275+ d .contact .efc_address = wp .zeros ((nconmax ,), dtype = wp .int32 )
276+ d .contact .worldid = wp .zeros ((nconmax ,), dtype = wp .int32 )
269277 d .qfrc_passive = wp .zeros ((nworld , mjm .nv ), dtype = wp .float32 )
270278 d .qfrc_spring = wp .zeros ((nworld , mjm .nv ), dtype = wp .float32 )
271279 d .qfrc_damper = wp .zeros ((nworld , mjm .nv ), dtype = wp .float32 )
@@ -316,13 +324,22 @@ def make_data(mjm: mujoco.MjModel, nworld: int = 1, njmax: int = -1) -> types.Da
316324
317325
318326def put_data (
319- mjm : mujoco .MjModel , mjd : mujoco .MjData , nworld : int = 1 , njmax : int = - 1
327+ mjm : mujoco .MjModel ,
328+ mjd : mujoco .MjData ,
329+ nworld : int = 1 ,
330+ nconmax : int = - 1 ,
331+ njmax : int = - 1 ,
320332) -> types .Data :
321333 d = types .Data ()
322334 d .nworld = nworld
335+ d .ncon_total = wp .array ([mjd .ncon * nworld ], dtype = wp .int32 , ndim = 1 )
323336 d .nefc_total = wp .array ([mjd .nefc * nworld ], dtype = wp .int32 , ndim = 1 )
324337
325338 # TODO(team): move to Model?
339+ if nconmax == - 1 :
340+ # TODO(team): heuristic for nconmax
341+ nconmax = 512
342+ d .nconmax = nconmax
326343 if njmax == - 1 :
327344 # TODO(team): heuristic for njmax
328345 njmax = 512
@@ -437,25 +454,63 @@ def tile(x):
437454 d .efc_force = wp .array (efc_force_fill , dtype = wp .float32 , ndim = 1 )
438455 d .efc_margin = wp .array (efc_margin_fill , dtype = wp .float32 , ndim = 1 )
439456 d .efc_worldid = wp .from_numpy (efc_worldid , dtype = wp .int32 )
457+
440458 d .act = wp .array (tile (mjd .act ), dtype = wp .float32 , ndim = 2 )
441459 d .act_dot = wp .array (tile (mjd .act_dot ), dtype = wp .float32 , ndim = 2 )
442- d .contact .dist = wp .array (tile (mjd .contact .dist ), dtype = wp .float32 , ndim = 2 )
443- d .contact .pos = wp .array (tile (mjd .contact .pos ), dtype = wp .vec3f , ndim = 2 )
444- d .contact .frame = wp .array (tile (mjd .contact .frame ), dtype = wp .mat33f , ndim = 2 )
445- d .contact .includemargin = wp .array (
446- tile (mjd .contact .includemargin ), dtype = wp .float32 , ndim = 2
460+
461+ ncon = mjd .ncon
462+ con_efc_address = np .zeros (nconmax , dtype = int )
463+ con_worldid = np .zeros (nconmax , dtype = int )
464+
465+ for i in range (nworld ):
466+ con_efc_address [i * ncon : (i + 1 ) * ncon ] = mjd .contact .efc_address + i * ncon
467+ con_worldid [i * ncon : (i + 1 ) * ncon ] = i
468+
469+ ncon_fill = nconmax - nworld * ncon
470+
471+ con_dist_fill = np .concatenate (
472+ [np .repeat (mjd .contact .dist , nworld , axis = 0 ), np .zeros (ncon_fill )]
473+ )
474+ con_pos_fill = np .vstack (
475+ [np .repeat (mjd .contact .pos , nworld , axis = 0 ), np .zeros ((ncon_fill , 3 ))]
447476 )
448- d .contact .friction = wp .array (tile (mjd .contact .friction ), dtype = wp .float32 , ndim = 3 )
449- d .contact .solref = wp .array (tile (mjd .contact .solref ), dtype = wp .float32 , ndim = 3 )
450- d .contact .solreffriction = wp .array (
451- tile (mjd .contact .solreffriction ), dtype = wp .float32 , ndim = 3
477+ con_frame_fill = np .vstack (
478+ [np .repeat (mjd .contact .frame , nworld , axis = 0 ), np .zeros ((ncon_fill , 9 ))]
452479 )
453- d .contact .solimp = wp .array (tile (mjd .contact .solimp ), dtype = wp .float32 , ndim = 3 )
454- d .contact .dim = wp .array (tile (mjd .contact .dim ), dtype = wp .int32 , ndim = 2 )
455- d .contact .geom = wp .array (tile (mjd .contact .geom ), dtype = wp .int32 , ndim = 3 )
456- d .contact .efc_address = wp .array (
457- tile (mjd .contact .efc_address ), dtype = wp .int32 , ndim = 2
480+ con_includemargin_fill = np .concatenate (
481+ [np .repeat (mjd .contact .includemargin , nworld , axis = 0 ), np .zeros (ncon_fill )]
458482 )
483+ con_friction_fill = np .vstack (
484+ [np .repeat (mjd .contact .friction , nworld , axis = 0 ), np .zeros ((ncon_fill , 5 ))]
485+ )
486+ con_solref_fill = np .vstack (
487+ [np .repeat (mjd .contact .solref , nworld , axis = 0 ), np .zeros ((ncon_fill , 2 ))]
488+ )
489+ con_solreffriction_fill = np .vstack (
490+ [np .repeat (mjd .contact .solreffriction , nworld , axis = 0 ), np .zeros ((ncon_fill , 2 ))]
491+ )
492+ con_solimp_fill = np .vstack (
493+ [np .repeat (mjd .contact .solimp , nworld , axis = 0 ), np .zeros ((ncon_fill , 5 ))]
494+ )
495+ con_dim_fill = np .concatenate (
496+ [np .repeat (mjd .contact .dim , nworld , axis = 0 ), np .zeros (ncon_fill )]
497+ )
498+ con_geom_fill = np .vstack (
499+ [np .repeat (mjd .contact .geom , nworld , axis = 0 ), np .zeros ((ncon_fill , 2 ))]
500+ )
501+
502+ d .contact .dist = wp .array (con_dist_fill , dtype = wp .float32 , ndim = 1 )
503+ d .contact .pos = wp .array (con_pos_fill , dtype = wp .vec3f , ndim = 1 )
504+ d .contact .frame = wp .array (con_frame_fill , dtype = wp .mat33f , ndim = 1 )
505+ d .contact .includemargin = wp .array (con_includemargin_fill , dtype = wp .float32 , ndim = 1 )
506+ d .contact .friction = wp .array (con_friction_fill , dtype = types .vec5 , ndim = 1 )
507+ d .contact .solref = wp .array (con_solref_fill , dtype = wp .vec2f , ndim = 1 )
508+ d .contact .solreffriction = wp .array (con_solreffriction_fill , dtype = wp .vec2f , ndim = 1 )
509+ d .contact .solimp = wp .array (con_solimp_fill , dtype = types .vec5 , ndim = 1 )
510+ d .contact .dim = wp .array (con_dim_fill , dtype = wp .int32 , ndim = 1 )
511+ d .contact .geom = wp .array (con_geom_fill , dtype = wp .vec2i , ndim = 1 )
512+ d .contact .efc_address = wp .array (con_efc_address , dtype = wp .int32 , ndim = 1 )
513+ d .contact .worldid = wp .array (con_worldid , dtype = wp .int32 , ndim = 1 )
459514
460515 d .xfrc_applied = wp .array (tile (mjd .xfrc_applied ), dtype = wp .spatial_vector , ndim = 2 )
461516 # internal tmp arrays
0 commit comments