@@ -232,15 +232,14 @@ def get_constraint_state(constraint_solver, solver):
232232
233233 efc_AR_shape = maybe_shape ((len_constraints_ , len_constraints_ , _B ), solver ._options .noslip_iterations > 0 )
234234 efc_b_shape = maybe_shape ((len_constraints_ , _B ), solver ._options .noslip_iterations > 0 )
235+ jac_relevant_dofs_shape = (len_constraints_ , solver .n_dofs_ , _B )
236+ jac_n_relevant_dofs_shape = (len_constraints_ , _B )
235237
236238 return StructConstraintState (
237239 n_constraints = V (dtype = gs .ti_int , shape = (_B ,)),
238240 ti_n_equalities = V (dtype = gs .ti_int , shape = (_B ,)),
239- jac = V (dtype = gs .ti_float , shape = (len_constraints_ , solver .n_dofs_ , _B )),
240241 diag = V (dtype = gs .ti_float , shape = (len_constraints_ , _B )),
241242 aref = V (dtype = gs .ti_float , shape = (len_constraints_ , _B )),
242- jac_relevant_dofs = V (dtype = gs .ti_int , shape = (len_constraints_ , solver .n_dofs_ , _B )),
243- jac_n_relevant_dofs = V (dtype = gs .ti_int , shape = (len_constraints_ , _B )),
244243 n_constraints_equality = V (dtype = gs .ti_int , shape = (_B ,)),
245244 n_constraints_frictionloss = V (dtype = gs .ti_int , shape = (_B ,)),
246245 improved = V (dtype = gs .ti_int , shape = (_B ,)),
@@ -293,6 +292,10 @@ def get_constraint_state(constraint_solver, solver):
293292 bw_Ju = V (dtype = gs .ti_float , shape = maybe_shape ((len_constraints_ , _B ), solver ._requires_grad )),
294293 bw_y = V (dtype = gs .ti_float , shape = maybe_shape ((len_constraints_ , _B ), solver ._requires_grad )),
295294 bw_w = V (dtype = gs .ti_float , shape = maybe_shape ((len_constraints_ , _B ), solver ._requires_grad )),
295+ # /!\ Moving allocation of these tensors at the end improves runtime speed by ~5-10% /!\
296+ jac = V (dtype = gs .ti_float , shape = (len_constraints_ , solver .n_dofs_ , _B )),
297+ jac_relevant_dofs = V (dtype = gs .ti_int , shape = jac_relevant_dofs_shape ),
298+ jac_n_relevant_dofs = V (dtype = gs .ti_int , shape = jac_n_relevant_dofs_shape ),
296299 )
297300
298301
@@ -518,8 +521,8 @@ def get_collider_state(
518521 broad_collision_pairs = V_VEC (2 , dtype = gs .ti_int , shape = (max (max_collision_pairs_broad , 1 ), _B )),
519522 active_buffer_awake = V (dtype = gs .ti_int , shape = (n_geoms , _B )),
520523 active_buffer_hib = V (dtype = gs .ti_int , shape = (n_geoms , _B )),
521- box_depth = V (dtype = gs .ti_float , shape = (collider_info . box_MAXCONPAIR [ None ] , _B )),
522- box_points = V_VEC (3 , dtype = gs .ti_float , shape = (collider_info . box_MAXCONPAIR [ None ] , _B )),
524+ box_depth = V (dtype = gs .ti_float , shape = (collider_static_config . n_contacts_per_pair , _B )),
525+ box_points = V_VEC (3 , dtype = gs .ti_float , shape = (collider_static_config . n_contacts_per_pair , _B )),
523526 box_pts = V_VEC (3 , dtype = gs .ti_float , shape = (6 , _B )),
524527 box_lines = V_VEC (6 , dtype = gs .ti_float , shape = (4 , _B )),
525528 box_linesu = V_VEC (6 , dtype = gs .ti_float , shape = (4 , _B )),
@@ -555,8 +558,6 @@ class StructColliderInfo(metaclass=BASE_METACLASS):
555558 mc_perturbation : V_ANNOTATION
556559 mc_tolerance : V_ANNOTATION
557560 mpr_to_sdf_overlap_ratio : V_ANNOTATION
558- # maximum number of contact points for box-box collision detection
559- box_MAXCONPAIR : V_ANNOTATION
560561 # differentiable contact tolerance
561562 diff_pos_tolerance : V_ANNOTATION
562563 diff_normal_tolerance : V_ANNOTATION
@@ -586,7 +587,6 @@ def get_collider_info(solver, n_vert_neighbors, collider_static_config, **kwargs
586587 mc_perturbation = V_SCALAR_FROM (dtype = gs .ti_float , value = kwargs ["mc_perturbation" ]),
587588 mc_tolerance = V_SCALAR_FROM (dtype = gs .ti_float , value = kwargs ["mc_tolerance" ]),
588589 mpr_to_sdf_overlap_ratio = V_SCALAR_FROM (dtype = gs .ti_float , value = kwargs ["mpr_to_sdf_overlap_ratio" ]),
589- box_MAXCONPAIR = V_SCALAR_FROM (dtype = gs .ti_int , value = kwargs ["box_MAXCONPAIR" ]),
590590 diff_pos_tolerance = V_SCALAR_FROM (dtype = gs .ti_float , value = kwargs ["diff_pos_tolerance" ]),
591591 diff_normal_tolerance = V_SCALAR_FROM (dtype = gs .ti_float , value = kwargs ["diff_normal_tolerance" ]),
592592 )
@@ -1731,6 +1731,7 @@ class StructRigidSimStaticConfig(metaclass=AutoInitMeta):
17311731class DataManager :
17321732 def __init__ (self , solver ):
17331733 self .rigid_global_info = get_rigid_global_info (solver )
1734+
17341735 self .dofs_info = get_dofs_info (solver )
17351736 self .dofs_state = get_dofs_state (solver )
17361737 self .links_info = get_links_info (solver )
@@ -1758,6 +1759,8 @@ def __init__(self, solver):
17581759 self .entities_info = get_entities_info (solver )
17591760 self .entities_state = get_entities_state (solver )
17601761
1762+ self .errno = V_SCALAR_FROM (dtype = gs .ti_int , value = 0 )
1763+
17611764
17621765DofsState = StructDofsState if gs .use_ndarray else ti .template ()
17631766DofsInfo = StructDofsInfo if gs .use_ndarray else ti .template ()
0 commit comments