@@ -115,7 +115,7 @@ def _init_static_config(self) -> None:
115115
116116 def _init_collision_fields (self ) -> None :
117117 # Pre-compute fields, as they are needed to initialize the collider state and info.
118- n_possible_pairs_ , collision_pair_validity = self ._compute_collision_pair_validity ()
118+ self . _n_possible_pairs , collision_pair_idx = self ._compute_collision_pair_idx ()
119119 vert_neighbors , vert_neighbor_start , vert_n_neighbors = self ._compute_verts_connectivity ()
120120 n_vert_neighbors = len (vert_neighbors )
121121
@@ -131,17 +131,17 @@ def _init_collision_fields(self) -> None:
131131 diff_pos_tolerance = self ._diff_pos_tolerance ,
132132 diff_normal_tolerance = self ._diff_normal_tolerance ,
133133 )
134- self ._init_collision_pair_validity ( collision_pair_validity )
134+ self ._init_collision_pair_idx ( collision_pair_idx )
135135 self ._init_verts_connectivity (vert_neighbors , vert_neighbor_start , vert_n_neighbors )
136- self ._init_max_contact_pairs (n_possible_pairs_ )
136+ self ._init_max_contact_pairs (self . _n_possible_pairs )
137137 self ._init_terrain_state ()
138138
139139 # Initialize [state], which stores every data that are may be updated at every single simulation step
140- n_possible_pairs = max (n_possible_pairs_ , 1 )
140+ n_possible_pairs_ = max (self . _n_possible_pairs , 1 )
141141 self ._collider_state = array_class .get_collider_state (
142142 self ._solver ,
143143 self ._solver ._static_rigid_sim_config ,
144- n_possible_pairs ,
144+ n_possible_pairs_ ,
145145 self ._solver ._options .multiplier_collision_broad_phase ,
146146 self ._collider_info ,
147147 self ._collider_static_config ,
@@ -152,9 +152,9 @@ def _init_collision_fields(self) -> None:
152152
153153 self .reset ()
154154
155- def _compute_collision_pair_validity (self ):
155+ def _compute_collision_pair_idx (self ):
156156 """
157- Compute the collision pair validity matrix .
157+ Compute flat indices of all valid collision pairs .
158158
159159 For each pair of geoms, determine if they can collide based on their properties and the solver configuration.
160160 """
@@ -183,7 +183,7 @@ def _compute_collision_pair_validity(self):
183183 entities_is_local_collision_mask = solver .entities_info .is_local_collision_mask .to_numpy ()
184184
185185 n_possible_pairs = 0
186- collision_pair_validity = np .zeros ((n_geoms , n_geoms ), dtype = gs .np_int )
186+ collision_pair_idx = np .full ((n_geoms , n_geoms ), fill_value = - 1 , dtype = gs .np_int )
187187 for i_ga in range (n_geoms ):
188188 for i_gb in range (i_ga + 1 , n_geoms ):
189189 i_la = geoms_link_idx [i_ga ]
@@ -229,10 +229,10 @@ def _compute_collision_pair_validity(self):
229229 if links_is_fixed [i_la ] and links_is_fixed [i_lb ]:
230230 continue
231231
232- collision_pair_validity [i_ga , i_gb ] = 1
232+ collision_pair_idx [i_ga , i_gb ] = n_possible_pairs
233233 n_possible_pairs += 1
234234
235- return n_possible_pairs , collision_pair_validity
235+ return n_possible_pairs , collision_pair_idx
236236
237237 def _compute_verts_connectivity (self ):
238238 """
@@ -255,8 +255,8 @@ def _compute_verts_connectivity(self):
255255
256256 return vert_neighbors , vert_neighbor_start , vert_n_neighbors
257257
258- def _init_collision_pair_validity (self , collision_pair_validity ):
259- self ._collider_info .collision_pair_validity .from_numpy (collision_pair_validity )
258+ def _init_collision_pair_idx (self , collision_pair_idx ):
259+ self ._collider_info .collision_pair_idx .from_numpy (collision_pair_idx )
260260
261261 def _init_verts_connectivity (self , vert_neighbors , vert_neighbor_start , vert_n_neighbors ):
262262 if self ._solver .n_verts > 0 :
@@ -299,22 +299,23 @@ def _init_terrain_state(self):
299299 def reset (self , envs_idx : npt .NDArray [np .int32 ] | None = None , cache_only : bool = False ) -> None :
300300 self ._contacts_info_cache .clear ()
301301 if gs .use_zerocopy :
302- mask = ( ) if envs_idx is None else envs_idx
302+ envs_idx = slice ( None ) if envs_idx is None else envs_idx
303303 if not cache_only :
304304 first_time = ti_to_torch (self ._collider_state .first_time , copy = False )
305305 if isinstance (envs_idx , torch .Tensor ):
306306 first_time .scatter_ (0 , envs_idx , True )
307307 else :
308- first_time [mask ] = True
308+ first_time [envs_idx ] = True
309+
309310 i_va_ws = ti_to_torch (self ._collider_state .contact_cache .i_va_ws , copy = False )
310311 normal = ti_to_torch (self ._collider_state .contact_cache .normal , copy = False )
311312 if isinstance (envs_idx , torch .Tensor ):
312- n_geoms = i_va_ws .shape [0 ]
313- i_va_ws .scatter_ (2 , envs_idx [None , None ].expand ((n_geoms , n_geoms , - 1 )), - 1 )
314- normal .scatter_ (2 , envs_idx [None , None , :, None ].expand ((n_geoms , n_geoms , - 1 , 3 )), 0.0 )
313+ max_possible_pairs = normal .shape [0 ]
314+ i_va_ws .scatter_ (2 , envs_idx [None , None ].expand ((2 , max_possible_pairs , - 1 )), - 1 )
315+ normal .scatter_ (1 , envs_idx [None , :, None ].expand ((max_possible_pairs , - 1 , 3 )), 0.0 )
315316 else :
316- i_va_ws [mask ] = - 1
317- normal [mask ] = 0.0
317+ i_va_ws [:, :, envs_idx ] = - 1
318+ normal [:, envs_idx ] = 0.0
318319 return
319320
320321 if envs_idx is None :
@@ -333,12 +334,16 @@ def clear(self, envs_idx=None):
333334 )
334335
335336 def detection (self ) -> None :
336- self ._contacts_info_cache .clear ()
337337 rigid_solver .kernel_update_geom_aabbs (
338338 self ._solver .geoms_state ,
339339 self ._solver .geoms_init_AABB ,
340340 self ._solver ._static_rigid_sim_config ,
341341 )
342+
343+ if self ._n_possible_pairs == 0 :
344+ return
345+
346+ self ._contacts_info_cache .clear ()
342347 func_broad_phase (
343348 self ._solver .links_state ,
344349 self ._solver .links_info ,
@@ -565,7 +570,7 @@ def collider_kernel_reset(
565570 collider_state : array_class .ColliderState ,
566571 cache_only : ti .template (),
567572):
568- n_geoms = collider_state .active_buffer .shape [0 ]
573+ max_possible_pairs = collider_state .contact_cache . normal .shape [0 ]
569574
570575 ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
571576 for i_b_ in range (envs_idx .shape [0 ]):
@@ -574,10 +579,10 @@ def collider_kernel_reset(
574579 if ti .static (not cache_only ):
575580 collider_state .first_time [i_b ] = True
576581
577- for i_ga , i_gb in ti . ndrange ( n_geoms , n_geoms ):
578- collider_state .contact_cache .i_va_ws [i_ga , i_gb , i_b ] = - 1
579- collider_state .contact_cache .i_va_ws [i_gb , i_ga , i_b ] = - 1
580- collider_state .contact_cache .normal [i_ga , i_gb , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
582+ for i_pair in range ( max_possible_pairs ):
583+ collider_state .contact_cache .i_va_ws [0 , i_pair , i_b ] = - 1
584+ collider_state .contact_cache .i_va_ws [1 , i_pair , i_b ] = - 1
585+ collider_state .contact_cache .normal [i_pair , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
581586
582587
583588# only used with hibernation ??
@@ -1214,7 +1219,7 @@ def func_check_collision_valid(
12141219 equalities_info : array_class .EqualitiesInfo ,
12151220 collider_info : array_class .ColliderInfo ,
12161221):
1217- is_valid = collider_info .collision_pair_validity [i_ga , i_gb ]
1222+ is_valid = collider_info .collision_pair_idx [i_ga , i_gb ] != - 1
12181223
12191224 if is_valid :
12201225 i_la = geoms_info .link_idx [i_ga ]
@@ -1408,9 +1413,10 @@ def func_broad_phase(
14081413 if not func_is_geom_aabbs_overlap (i_ga , i_gb , i_b , geoms_state , geoms_info ):
14091414 # Clear collision normal cache if not in contact
14101415 if ti .static (not static_rigid_sim_config .enable_mujoco_compatibility ):
1411- collider_state .contact_cache .i_va_ws [i_ga , i_gb , i_b ] = - 1
1412- collider_state .contact_cache .i_va_ws [i_gb , i_ga , i_b ] = - 1
1413- collider_state .contact_cache .normal [i_ga , i_gb , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
1416+ i_pair = collider_info .collision_pair_idx [i_ga , i_gb ]
1417+ collider_state .contact_cache .i_va_ws [0 , i_pair , i_b ] = - 1
1418+ collider_state .contact_cache .i_va_ws [1 , i_pair , i_b ] = - 1
1419+ collider_state .contact_cache .normal [i_pair , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
14141420 continue
14151421
14161422 i_p = collider_state .n_broad_pairs [i_b ]
@@ -1465,11 +1471,10 @@ def func_broad_phase(
14651471 if not func_is_geom_aabbs_overlap (i_ga , i_gb , i_b , geoms_state , geoms_info ):
14661472 # Clear collision normal cache if not in contact
14671473 if ti .static (not static_rigid_sim_config .enable_mujoco_compatibility ):
1468- collider_state .contact_cache .i_va_ws [i_ga , i_gb , i_b ] = - 1
1469- collider_state .contact_cache .i_va_ws [i_gb , i_ga , i_b ] = - 1
1470- collider_state .contact_cache .normal [i_ga , i_gb , i_b ] = ti .Vector .zero (
1471- gs .ti_float , 3
1472- )
1474+ i_pair = collider_info .collision_pair_idx [i_ga , i_gb ]
1475+ collider_state .contact_cache .i_va_ws [0 , i_pair , i_b ] = - 1
1476+ collider_state .contact_cache .i_va_ws [1 , i_pair , i_b ] = - 1
1477+ collider_state .contact_cache .normal [i_pair , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
14731478 continue
14741479
14751480 collider_state .broad_collision_pairs [collider_state .n_broad_pairs [i_b ], i_b ][0 ] = i_ga
@@ -1501,11 +1506,10 @@ def func_broad_phase(
15011506
15021507 if not func_is_geom_aabbs_overlap (i_ga , i_gb , i_b , geoms_state , geoms_info ):
15031508 # Clear collision normal cache if not in contact
1504- collider_state .contact_cache .i_va_ws [i_ga , i_gb , i_b ] = - 1
1505- collider_state .contact_cache .i_va_ws [i_gb , i_ga , i_b ] = - 1
1506- collider_state .contact_cache .normal [i_ga , i_gb , i_b ] = ti .Vector .zero (
1507- gs .ti_float , 3
1508- )
1509+ i_pair = collider_info .collision_pair_idx [i_ga , i_gb ]
1510+ collider_state .contact_cache .i_va_ws [0 , i_pair , i_b ] = - 1
1511+ collider_state .contact_cache .i_va_ws [1 , i_pair , i_b ] = - 1
1512+ collider_state .contact_cache .normal [i_pair , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
15091513 continue
15101514
15111515 collider_state .broad_collision_pairs [collider_state .n_broad_pairs [i_b ], i_b ][0 ] = i_ga
@@ -2395,6 +2399,7 @@ def func_convex_convex_contact(
23952399 axis_1 = ti .Vector .zero (gs .ti_float , 3 )
23962400 qrot = ti .Vector .zero (gs .ti_float , 4 )
23972401
2402+ i_pair = collider_info .collision_pair_idx [(i_gb , i_ga ) if i_ga > i_gb else (i_ga , i_gb )]
23982403 for i_detection in range (5 ):
23992404 try_sdf = False
24002405 prefer_sdf = False
@@ -2410,8 +2415,7 @@ def func_convex_convex_contact(
24102415 if (multi_contact and is_col_0 ) or (i_detection == 0 ):
24112416 if geoms_info .type [i_ga ] == gs .GEOM_TYPE .PLANE :
24122417 plane_dir = ti .Vector (
2413- [geoms_info .data [i_ga ][0 ], geoms_info .data [i_ga ][1 ], geoms_info .data [i_ga ][2 ]],
2414- dt = gs .ti_float ,
2418+ [geoms_info .data [i_ga ][0 ], geoms_info .data [i_ga ][1 ], geoms_info .data [i_ga ][2 ]], dt = gs .ti_float
24152419 )
24162420 plane_dir = gu .ti_transform_by_quat (plane_dir , geoms_state .quat [i_ga , i_b ])
24172421 normal = - plane_dir .normalized ()
@@ -2438,7 +2442,7 @@ def func_convex_convex_contact(
24382442 # Try using MPR before anything else
24392443 is_mpr_updated = False
24402444 is_mpr_guess_direction_available = True
2441- normal_ws = collider_state .contact_cache .normal [i_ga , i_gb , i_b ]
2445+ normal_ws = collider_state .contact_cache .normal [i_pair , i_b ]
24422446 for i_mpr in range (2 ):
24432447 if i_mpr == 1 :
24442448 # Try without warm-start if no contact was detected using it.
@@ -2604,8 +2608,8 @@ def func_convex_convex_contact(
26042608 penetration_a = gs .ti_float (0.0 )
26052609 contact_pos_a = ti .Vector .zero (gs .ti_float , 3 )
26062610 contact_pos_b = ti .Vector .zero (gs .ti_float , 3 )
2607- i_va = collider_state .contact_cache .i_va_ws [i_ga , i_gb , i_b ]
2608- i_vb = collider_state .contact_cache .i_va_ws [i_gb , i_ga , i_b ]
2611+ i_va = collider_state .contact_cache .i_va_ws [0 , i_pair , i_b ]
2612+ i_vb = collider_state .contact_cache .i_va_ws [1 , i_pair , i_b ]
26092613 for i_sdf in range (2 ):
26102614 is_col_i , normal_i , penetration_i , contact_pos_i , i_vi = func_contact_convex_convex_sdf (
26112615 i_ga if i_sdf == 0 else i_gb ,
@@ -2654,14 +2658,14 @@ def func_convex_convex_contact(
26542658 normal = normal_a
26552659 penetration = penetration_a
26562660 contact_pos = contact_pos_a
2657- collider_state .contact_cache .i_va_ws [i_ga , i_gb , i_b ] = i_va
2661+ collider_state .contact_cache .i_va_ws [0 , i_pair , i_b ] = i_va
26582662 elif is_col_b and (
26592663 not is_col_a or penetration_b > max (penetration_a , (not prefer_sdf ) * penetration )
26602664 ):
26612665 normal = normal_b
26622666 penetration = penetration_b
26632667 contact_pos = contact_pos_b
2664- collider_state .contact_cache .i_va_ws [i_gb , i_ga , i_b ] = i_vb
2668+ collider_state .contact_cache .i_va_ws [1 , i_pair , i_b ] = i_vb
26652669 elif not is_col_a and not is_col_b :
26662670 is_col = False
26672671
@@ -2701,12 +2705,12 @@ def func_convex_convex_contact(
27012705 if ti .static (
27022706 collider_static_config .ccd_algorithm in (CCD_ALGORITHM_CODE .MPR , CCD_ALGORITHM_CODE .GJK )
27032707 ):
2704- collider_state .contact_cache .normal [i_ga , i_gb , i_b ] = normal
2708+ collider_state .contact_cache .normal [i_pair , i_b ] = normal
27052709 else :
27062710 # Clear collision normal cache if not in contact
2707- collider_state .contact_cache .i_va_ws [i_ga , i_gb , i_b ] = - 1
2708- collider_state .contact_cache .i_va_ws [i_gb , i_ga , i_b ] = - 1
2709- collider_state .contact_cache .normal [i_ga , i_gb , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
2711+ collider_state .contact_cache .i_va_ws [0 , i_pair , i_b ] = - 1
2712+ collider_state .contact_cache .i_va_ws [1 , i_pair , i_b ] = - 1
2713+ collider_state .contact_cache .normal [i_pair , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
27102714
27112715 elif multi_contact and is_col_0 > 0 and is_col > 0 :
27122716 if ti .static (collider_static_config .ccd_algorithm in (CCD_ALGORITHM_CODE .MPR , CCD_ALGORITHM_CODE .GJK )):
0 commit comments