@@ -232,8 +232,8 @@ def get_constraint_state(constraint_solver, solver):
232232
233233 efc_AR_shape = maybe_shape ((len_constraints_ , len_constraints_ , _B ), solver ._options .noslip_iterations > 0 )
234234 efc_b_shape = maybe_shape ((len_constraints_ , _B ), solver ._options .noslip_iterations > 0 )
235- jac_relevant_dofs_shape = ( len_constraints_ , solver .n_dofs_ , _B )
236- jac_n_relevant_dofs_shape = ( len_constraints_ , _B )
235+ jac_relevant_dofs_shape = maybe_shape (( len_constraints_ , solver .n_dofs_ , _B ), constraint_solver . sparse_solve )
236+ jac_n_relevant_dofs_shape = maybe_shape (( len_constraints_ , _B ), constraint_solver . sparse_solve )
237237
238238 return StructConstraintState (
239239 n_constraints = V (dtype = gs .ti_int , shape = (_B ,)),
@@ -513,6 +513,19 @@ def get_collider_state(
513513 max_contact_pairs = max_collision_pairs * collider_static_config .n_contacts_per_pair
514514 requires_grad = static_rigid_sim_config .requires_grad
515515
516+ box_depth_shape = maybe_shape (
517+ (collider_static_config .n_contacts_per_pair , _B ), static_rigid_sim_config .box_box_detection
518+ )
519+ box_points_shape = maybe_shape (
520+ (collider_static_config .n_contacts_per_pair , _B ), static_rigid_sim_config .box_box_detection
521+ )
522+ box_pts_shape = maybe_shape ((6 , _B ), static_rigid_sim_config .box_box_detection )
523+ box_lines_shape = maybe_shape ((4 , _B ), static_rigid_sim_config .box_box_detection )
524+ box_linesu_shape = maybe_shape ((4 , _B ), static_rigid_sim_config .box_box_detection )
525+ box_axi_shape = maybe_shape ((3 , _B ), static_rigid_sim_config .box_box_detection )
526+ box_ppts2_shape = maybe_shape ((4 , 2 , _B ), static_rigid_sim_config .box_box_detection )
527+ box_pu_shape = maybe_shape ((4 , _B ), static_rigid_sim_config .box_box_detection )
528+
516529 return StructColliderState (
517530 sort_buffer = get_sort_buffer (solver ),
518531 contact_data = get_contact_data (solver , max_contact_pairs , requires_grad ),
@@ -521,14 +534,14 @@ def get_collider_state(
521534 broad_collision_pairs = V_VEC (2 , dtype = gs .ti_int , shape = (max (max_collision_pairs_broad , 1 ), _B )),
522535 active_buffer_awake = V (dtype = gs .ti_int , shape = (n_geoms , _B )),
523536 active_buffer_hib = V (dtype = gs .ti_int , shape = (n_geoms , _B )),
524- box_depth = V (dtype = gs .ti_float , shape = ( collider_static_config . n_contacts_per_pair , _B ) ),
525- box_points = V_VEC (3 , dtype = gs .ti_float , shape = ( collider_static_config . n_contacts_per_pair , _B ) ),
526- box_pts = V_VEC (3 , dtype = gs .ti_float , shape = ( 6 , _B ) ),
527- box_lines = V_VEC (6 , dtype = gs .ti_float , shape = ( 4 , _B ) ),
528- box_linesu = V_VEC (6 , dtype = gs .ti_float , shape = ( 4 , _B ) ),
529- box_axi = V_VEC (3 , dtype = gs .ti_float , shape = ( 3 , _B ) ),
530- box_ppts2 = V (dtype = gs .ti_float , shape = ( 4 , 2 , _B ) ),
531- box_pu = V_VEC (3 , dtype = gs .ti_float , shape = ( 4 , _B ) ),
537+ box_depth = V (dtype = gs .ti_float , shape = box_depth_shape ),
538+ box_points = V_VEC (3 , dtype = gs .ti_float , shape = box_points_shape ),
539+ box_pts = V_VEC (3 , dtype = gs .ti_float , shape = box_pts_shape ),
540+ box_lines = V_VEC (6 , dtype = gs .ti_float , shape = box_lines_shape ),
541+ box_linesu = V_VEC (6 , dtype = gs .ti_float , shape = box_linesu_shape ),
542+ box_axi = V_VEC (3 , dtype = gs .ti_float , shape = box_axi_shape ),
543+ box_ppts2 = V (dtype = gs .ti_float , shape = box_ppts2_shape ),
544+ box_pu = V_VEC (3 , dtype = gs .ti_float , shape = box_pu_shape ),
532545 xyz_max_min = V (dtype = gs .ti_float , shape = (6 , _B )),
533546 prism = V_VEC (3 , dtype = gs .ti_float , shape = (6 , _B )),
534547 n_contacts = V (dtype = gs .ti_int , shape = (_B ,)),
0 commit comments