99
1010import genesis as gs
1111import genesis .utils .geom as gu
12- from genesis .styles import colors , formats
1312import genesis .utils .array_class as array_class
1413import genesis .engine .solvers .rigid .gjk_decomp as gjk
1514import genesis .engine .solvers .rigid .diff_gjk_decomp as diff_gjk
1615import genesis .engine .solvers .rigid .mpr_decomp as mpr
1716import genesis .utils .sdf_decomp as sdf
1817import genesis .engine .solvers .rigid .support_field_decomp as support_field
1918import genesis .engine .solvers .rigid .rigid_solver_decomp as rigid_solver
19+ from genesis .utils .misc import tensor_to_array , ti_to_torch , ti_to_numpy
2020
2121from .mpr_decomp import MPR
2222from .gjk_decomp import GJK
@@ -62,6 +62,22 @@ def __init__(self, rigid_solver: "RigidSolver"):
6262 self ._init_static_config ()
6363 self ._init_collision_fields ()
6464
65+ if gs .use_zerocopy :
66+ self ._contacts_info : dict [str , torch .Tensor ] = {}
67+ for key , name in (
68+ ("link_a" , "link_a" ),
69+ ("link_b" , "link_b" ),
70+ ("geom_a" , "geom_a" ),
71+ ("geom_b" , "geom_b" ),
72+ ("penetration" , "penetration" ),
73+ ("position" , "pos" ),
74+ ("normal" , "normal" ),
75+ ("force" , "force" ),
76+ ):
77+ self ._contacts_info [key ] = ti_to_torch (
78+ getattr (self ._collider_state .contact_data , name ), transpose = True , copy = False
79+ )
80+
6581 # Support field used for mpr and gjk. Rather than having separate support fields for each algorithm, keep only
6682 # one copy here to save memory and maintain cleaner code.
6783 self ._support_field = SupportField (rigid_solver )
@@ -131,8 +147,8 @@ def _init_collision_fields(self) -> None:
131147 self ._collider_static_config ,
132148 )
133149
134- # [ contacts_info_cache] is not used in Taichi kernels, so keep it outside of the collider state / info.
135- self ._contacts_info_cache = {}
150+ # ' contacts_info_cache' is not used in Taichi kernels, so keep it outside of the collider state / info
151+ self ._contacts_info_cache : dict [ tuple [ bool , bool ], dict [ str , torch . Tensor | tuple [ torch . Tensor ]]] = {}
136152
137153 self .reset ()
138154
@@ -288,7 +304,7 @@ def reset(self, envs_idx: npt.NDArray[np.int32] | None = None) -> None:
288304 self ._solver ._static_rigid_sim_config ,
289305 self ._collider_state ,
290306 )
291- self ._contacts_info_cache = {}
307+ self ._contacts_info_cache . clear ()
292308
293309 def clear (self , envs_idx = None ):
294310 if envs_idx is None :
@@ -302,7 +318,7 @@ def clear(self, envs_idx=None):
302318 )
303319
304320 def detection (self ) -> None :
305- self ._contacts_info_cache = {}
321+ self ._contacts_info_cache . clear ()
306322 rigid_solver .kernel_update_geom_aabbs (
307323 self ._solver .geoms_state ,
308324 self ._solver .geoms_init_AABB ,
@@ -391,18 +407,46 @@ def detection(self) -> None:
391407
392408 def get_contacts (self , as_tensor : bool = True , to_torch : bool = True , keep_batch_dim : bool = False ):
393409 # Early return if already pre-computed
394- contacts_info = self ._contacts_info_cache .get ((as_tensor , to_torch ))
395- if contacts_info is not None :
410+ contacts_info = self ._contacts_info_cache .setdefault ((as_tensor , to_torch ), {})
411+ if contacts_info :
412+ return contacts_info .copy ()
413+
414+ n_envs = self ._solver .n_envs
415+ if gs .use_zerocopy :
416+ if as_tensor or n_envs == 0 :
417+ n_contacts_max = ti_to_torch (self ._collider_state .n_contacts_max , copy = False ).item ()
418+ else :
419+ n_contacts = ti_to_torch (self ._collider_state .n_contacts , copy = False )
420+
421+ for key , data in self ._contacts_info .items ():
422+ if n_envs == 0 :
423+ data = data [0 , :n_contacts_max ]
424+ if not to_torch :
425+ data = tensor_to_array (data )
426+ else :
427+ if as_tensor :
428+ data = data [:, :n_contacts_max ]
429+ if not to_torch :
430+ data = tensor_to_array (data )
431+ else :
432+ if not to_torch :
433+ data = tensor_to_array (data )
434+ if keep_batch_dim :
435+ data = tuple ([data [i : i + 1 , :j ] for i , j in enumerate (n_contacts .tolist ())])
436+ else :
437+ data = tuple ([data [i , :j ] for i , j in enumerate (n_contacts .tolist ())])
438+ contacts_info [key ] = data
439+
396440 return contacts_info .copy ()
397441
398442 # Find out how much dynamic memory must be allocated
399- n_contacts = tuple (self ._collider_state .n_contacts .to_numpy ())
400- n_envs = len (n_contacts )
401- n_contacts_max = max (n_contacts )
443+ n_contacts = ti_to_numpy (self ._collider_state .n_contacts )
444+ n_contacts_max = n_contacts .max ().item ()
402445 if as_tensor :
403- out_size = n_contacts_max * n_envs
446+ out_size = n_contacts_max * max ( n_envs , 1 )
404447 else :
405448 * n_contacts_starts , out_size = np .cumsum (n_contacts )
449+ n_contacts = n_contacts .tolist ()
406450
407451 # Allocate output buffer
408452 if to_torch :
@@ -415,29 +459,23 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
415459 # Copy contact data
416460 if n_contacts_max > 0 :
417461 collider_kernel_get_contacts (
418- as_tensor ,
419- iout ,
420- fout ,
421- self ._solver ._rigid_global_info ,
422- self ._solver ._static_rigid_sim_config ,
423- self ._collider_state ,
424- self ._collider_info ,
462+ as_tensor , iout , fout , self ._solver ._static_rigid_sim_config , self ._collider_state
425463 )
426464
427465 # Build structured view (no copy)
428466 if as_tensor :
429- if self . _solver . n_envs > 0 :
467+ if n_envs > 0 :
430468 iout = iout .reshape ((n_envs , n_contacts_max , 4 ))
431469 fout = fout .reshape ((n_envs , n_contacts_max , 10 ))
432- if keep_batch_dim and self . _solver . n_envs == 0 :
470+ if keep_batch_dim and n_envs == 0 :
433471 iout = iout .reshape ((1 , n_contacts_max , 4 ))
434472 fout = fout .reshape ((1 , n_contacts_max , 10 ))
435473 iout_chunks = (iout [..., 0 ], iout [..., 1 ], iout [..., 2 ], iout [..., 3 ])
436474 fout_chunks = (fout [..., 0 ], fout [..., 1 :4 ], fout [..., 4 :7 ], fout [..., 7 :])
437475 values = (* iout_chunks , * fout_chunks )
438476 else :
439477 # Split smallest dimension first, then largest dimension
440- if self . _solver . n_envs == 0 :
478+ if n_envs == 0 :
441479 iout_chunks = (iout [..., 0 ], iout [..., 1 ], iout [..., 2 ], iout [..., 3 ])
442480 fout_chunks = (fout [..., 0 ], fout [..., 1 :4 ], fout [..., 4 :7 ], fout [..., 7 :])
443481 values = (* iout_chunks , * fout_chunks )
@@ -454,7 +492,7 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
454492 else :
455493 iout_chunks = (iout [..., 0 ], iout [..., 1 ], iout [..., 2 ], iout [..., 3 ])
456494 fout_chunks = (fout [..., 0 ], fout [..., 1 :4 ], fout [..., 4 :7 ], fout [..., 7 :])
457- if self . _solver . n_envs == 1 :
495+ if n_envs == 1 :
458496 values = [(value ,) for value in (* iout_chunks , * fout_chunks )]
459497 else :
460498 if to_torch :
@@ -465,13 +503,11 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
465503 fout_chunks = (np .split (out , n_contacts_starts ) for out in fout_chunks )
466504 values = (* iout_chunks , * fout_chunks )
467505
468- contacts_info = dict (
506+ # Store contact information in cache
507+ contacts_info .update (
469508 zip (("link_a" , "link_b" , "geom_a" , "geom_b" , "penetration" , "position" , "normal" , "force" ), values )
470509 )
471510
472- # Cache contact information before returning
473- self ._contacts_info_cache [(as_tensor , to_torch )] = contacts_info
474-
475511 return contacts_info .copy ()
476512
477513 def backward (self , dL_dposition , dL_dnormal , dL_dpenetration ):
@@ -514,16 +550,16 @@ def collider_kernel_reset(
514550 static_rigid_sim_config : ti .template (),
515551 collider_state : array_class .ColliderState ,
516552):
553+ n_geoms = collider_state .active_buffer .shape [0 ]
554+
517555 ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
518556 for i_b_ in range (envs_idx .shape [0 ]):
519557 i_b = envs_idx [i_b_ ]
520558 collider_state .first_time [i_b ] = 1
521- n_geoms = collider_state .active_buffer .shape [0 ]
522- for i_ga in range (n_geoms ):
523- for i_gb in range (n_geoms ):
524- collider_state .contact_cache .i_va_ws [i_ga , i_gb , i_b ] = - 1
525- collider_state .contact_cache .i_va_ws [i_gb , i_ga , i_b ] = - 1
526- collider_state .contact_cache .normal [i_ga , i_gb , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
559+ for i_ga , i_gb in ti .ndrange (n_geoms , n_geoms ):
560+ collider_state .contact_cache .i_va_ws [i_ga , i_gb , i_b ] = - 1
561+ collider_state .contact_cache .i_va_ws [i_gb , i_ga , i_b ] = - 1
562+ collider_state .contact_cache .normal [i_ga , i_gb , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
527563
528564
529565# only used with hibernation ??
@@ -574,31 +610,34 @@ def kernel_collider_clear(
574610
575611 collider_state .n_contacts_hibernated [i_b ] = i_c_hibernated + 1
576612
613+ for i_c in range (collider_state .n_contacts [i_b ]):
614+ collider_state .contact_data .link_a [i_c , i_b ] = - 1
615+ collider_state .contact_data .link_b [i_c , i_b ] = - 1
616+ collider_state .contact_data .geom_a [i_c , i_b ] = - 1
617+ collider_state .contact_data .geom_b [i_c , i_b ] = - 1
618+ collider_state .contact_data .penetration [i_c , i_b ] = 0.0
619+ collider_state .contact_data .pos [i_c , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
620+ collider_state .contact_data .normal [i_c , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
621+ collider_state .contact_data .force [i_c , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
622+
623+ if ti .static (static_rigid_sim_config .use_hibernation ):
577624 collider_state .n_contacts [i_b ] = collider_state .n_contacts_hibernated [i_b ]
578625 else :
579626 collider_state .n_contacts [i_b ] = 0
580627
628+ collider_state .n_contacts_max [None ] = 0
629+
581630
582631@ti .kernel (fastcache = gs .use_fastcache )
583632def collider_kernel_get_contacts (
584633 is_padded : ti .template (),
585634 iout : ti .types .ndarray (),
586635 fout : ti .types .ndarray (),
587- rigid_global_info : array_class .RigidGlobalInfo ,
588636 static_rigid_sim_config : ti .template (),
589637 collider_state : array_class .ColliderState ,
590- collider_info : array_class .ColliderInfo ,
591638):
592639 _B = collider_state .active_buffer .shape [1 ]
593- n_contacts_max = gs .ti_int (0 )
594-
595- # this is a reduction operation (global max), we have to serialize it
596- # TODO: a good unittest and a better implementation from gstaichi for this kind of reduction
597- ti .loop_config (serialize = True )
598- for i_b in range (_B ):
599- n_contacts = collider_state .n_contacts [i_b ]
600- if n_contacts > n_contacts_max :
601- n_contacts_max = n_contacts
640+ n_contacts_max = collider_state .n_contacts_max [None ]
602641
603642 ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
604643 for i_b in range (_B ):
@@ -1187,9 +1226,10 @@ def func_collision_clear(
11871226 static_rigid_sim_config : ti .template (),
11881227):
11891228 _B = collider_state .n_contacts .shape [0 ]
1190- if ti .static (static_rigid_sim_config .use_hibernation ):
1191- ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
1192- for i_b in range (_B ):
1229+
1230+ ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
1231+ for i_b in range (_B ):
1232+ if ti .static (static_rigid_sim_config .use_hibernation ):
11931233 collider_state .n_contacts_hibernated [i_b ] = 0
11941234
11951235 # Advect hibernated contacts
@@ -1210,12 +1250,23 @@ def func_collision_clear(
12101250 collider_state .contact_data [i_c_hibernated , i_b ] = collider_state .contact_data [i_c , i_b ]
12111251 collider_state .n_contacts_hibernated [i_b ] = i_c_hibernated + 1
12121252
1253+ for i_c in range (collider_state .n_contacts [i_b ]):
1254+ collider_state .contact_data .link_a [i_c , i_b ] = - 1
1255+ collider_state .contact_data .link_b [i_c , i_b ] = - 1
1256+ collider_state .contact_data .geom_a [i_c , i_b ] = - 1
1257+ collider_state .contact_data .geom_b [i_c , i_b ] = - 1
1258+ collider_state .contact_data .penetration [i_c , i_b ] = 0.0
1259+ collider_state .contact_data .pos [i_c , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
1260+ collider_state .contact_data .normal [i_c , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
1261+ collider_state .contact_data .force [i_c , i_b ] = ti .Vector .zero (gs .ti_float , 3 )
1262+
1263+ if ti .static (static_rigid_sim_config .use_hibernation ):
12131264 collider_state .n_contacts [i_b ] = collider_state .n_contacts_hibernated [i_b ]
1214- else :
1215- ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
1216- for i_b in range (_B ):
1265+ else :
12171266 collider_state .n_contacts [i_b ] = 0
12181267
1268+ collider_state .n_contacts_max [None ] = 0
1269+
12191270
12201271@ti .kernel (fastcache = gs .use_fastcache )
12211272def func_broad_phase (
@@ -2087,6 +2138,7 @@ def func_add_contact(
20872138 collider_state .contact_data .link_b [i_c , i_b ] = geoms_info .link_idx [i_gb ]
20882139
20892140 collider_state .n_contacts [i_b ] = i_c + 1
2141+ ti .atomic_max (collider_state .n_contacts_max [None ], i_c + 1 )
20902142 else :
20912143 errno [None ] = 2
20922144
0 commit comments