@@ -360,15 +360,15 @@ class StructContactData:
360360 link_b : V_ANNOTATION
361361
362362
363- def get_contact_data (solver , max_contact_pairs ):
363+ def get_contact_data (solver , max_contact_pairs , requires_grad ):
364364 f_batch = solver ._batch_shape
365365 max_contact_pairs_ = max (1 , max_contact_pairs )
366366 kwargs = {
367367 "geom_a" : V (dtype = gs .ti_int , shape = f_batch (max_contact_pairs_ )),
368368 "geom_b" : V (dtype = gs .ti_int , shape = f_batch (max_contact_pairs_ )),
369- "penetration " : V (dtype = gs .ti_float , shape = f_batch (max_contact_pairs_ )),
370- "normal " : V_VEC ( 3 , dtype = gs .ti_float , shape = f_batch (max_contact_pairs_ )),
371- "pos " : V_VEC ( 3 , dtype = gs .ti_float , shape = f_batch (max_contact_pairs_ )),
369+ "normal " : V (dtype = gs .ti_vec3 , shape = f_batch (max_contact_pairs_ ), needs_grad = requires_grad ),
370+ "pos " : V ( dtype = gs .ti_vec3 , shape = f_batch (max_contact_pairs_ ), needs_grad = requires_grad ),
371+ "penetration " : V ( dtype = gs .ti_float , shape = f_batch (max_contact_pairs_ ), needs_grad = requires_grad ),
372372 "friction" : V (dtype = gs .ti_float , shape = f_batch (max_contact_pairs_ )),
373373 "sol_params" : V_VEC (7 , dtype = gs .ti_float , shape = f_batch (max_contact_pairs_ )),
374374 "force" : V (dtype = gs .ti_vec3 , shape = f_batch (max_contact_pairs_ )),
@@ -389,6 +389,62 @@ def __init__(self):
389389 return ClassContactData ()
390390
391391
392+ @dataclasses .dataclass
393+ class StructDiffContactInput :
394+ ### Non-differentiable input data
395+ # Geom id of the two geometries
396+ geom_a : V_ANNOTATION
397+ geom_b : V_ANNOTATION
398+ # Local positions of the 3 vertices from the two geometries that define the face on the Minkowski difference
399+ local_pos1_a : V_ANNOTATION
400+ local_pos1_b : V_ANNOTATION
401+ local_pos1_c : V_ANNOTATION
402+ local_pos2_a : V_ANNOTATION
403+ local_pos2_b : V_ANNOTATION
404+ local_pos2_c : V_ANNOTATION
405+ # Local positions of the 1 vertex from the two geometries that define the support point for the face above
406+ w_local_pos1 : V_ANNOTATION
407+ w_local_pos2 : V_ANNOTATION
408+ # Reference id of the contact point, which is needed for the backward pass
409+ ref_id : V_ANNOTATION
410+ # Flag whether the contact data can be computed in numerically stable way in both the forward and backward passes
411+ valid : V_ANNOTATION
412+ ### Differentiable input data
413+ # Reference penetration depth, which is needed for computing the weight of the contact point
414+ ref_penetration : V_ANNOTATION
415+
416+
417+ def get_diff_contact_input (solver , max_contacts_per_pair ):
418+ _B = solver ._B
419+ kwargs = {
420+ "geom_a" : V (dtype = gs .ti_int , shape = (_B , max_contacts_per_pair )),
421+ "geom_b" : V (dtype = gs .ti_int , shape = (_B , max_contacts_per_pair )),
422+ "local_pos1_a" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_contacts_per_pair )),
423+ "local_pos1_b" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_contacts_per_pair )),
424+ "local_pos1_c" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_contacts_per_pair )),
425+ "local_pos2_a" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_contacts_per_pair )),
426+ "local_pos2_b" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_contacts_per_pair )),
427+ "local_pos2_c" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_contacts_per_pair )),
428+ "w_local_pos1" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_contacts_per_pair )),
429+ "w_local_pos2" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_contacts_per_pair )),
430+ "ref_id" : V (dtype = gs .ti_int , shape = (_B , max_contacts_per_pair )),
431+ "valid" : V (dtype = gs .ti_int , shape = (_B , max_contacts_per_pair )),
432+ "ref_penetration" : V (dtype = gs .ti_float , shape = (_B , max_contacts_per_pair ), needs_grad = True ),
433+ }
434+
435+ if use_ndarray :
436+ return StructDiffContactInput (** kwargs )
437+ else :
438+
439+ @ti .data_oriented
440+ class ClassDiffContactInput :
441+ def __init__ (self ):
442+ for k , v in kwargs .items ():
443+ setattr (self , k , v )
444+
445+ return ClassDiffContactInput ()
446+
447+
392448@dataclasses .dataclass
393449class StructSortBuffer :
394450 value : V_ANNOTATION
@@ -547,19 +603,22 @@ class StructColliderState:
547603 n_contacts_hibernated : V_ANNOTATION
548604 first_time : V_ANNOTATION
549605 contact_cache : StructContactCache
606+ # Input data for differentiable contact detection used in the backward pass
607+ diff_contact_input : StructDiffContactInput
550608
551609
552- def get_collider_state (solver , n_possible_pairs , collider_static_config ):
610+ def get_collider_state (solver , static_rigid_sim_config , n_possible_pairs , collider_static_config ):
553611 _B = solver ._B
554612 f_batch = solver ._batch_shape
555613 n_geoms = solver .n_geoms_
556614 max_collision_pairs = min (solver ._max_collision_pairs , n_possible_pairs )
557615 max_collision_pairs_broad = max_collision_pairs * collider_static_config .max_collision_pairs_broad_k
558616 max_contact_pairs = max_collision_pairs * collider_static_config .n_contacts_per_pair
617+ requires_grad = static_rigid_sim_config .requires_grad
559618
560619 ############## broad phase SAP ##############
561620
562- contact_data = get_contact_data (solver , max_contact_pairs )
621+ contact_data = get_contact_data (solver , max_contact_pairs , requires_grad )
563622 sort_buffer = get_sort_buffer (solver )
564623 contact_cache = get_contact_cache (solver )
565624 kwargs = {
@@ -584,6 +643,7 @@ def get_collider_state(solver, n_possible_pairs, collider_static_config):
584643 "n_contacts_hibernated" : V (dtype = gs .ti_int , shape = _B ),
585644 "first_time" : V (dtype = gs .ti_int , shape = _B ),
586645 "contact_cache" : contact_cache ,
646+ "diff_contact_input" : get_diff_contact_input (solver , max_contact_pairs if requires_grad else 1 ),
587647 }
588648
589649 if use_ndarray :
@@ -722,6 +782,8 @@ class StructMDVertex:
722782 # Vertex of the Minkowski difference
723783 obj1 : V_ANNOTATION
724784 obj2 : V_ANNOTATION
785+ local_obj1 : V_ANNOTATION
786+ local_obj2 : V_ANNOTATION
725787 id1 : V_ANNOTATION
726788 id2 : V_ANNOTATION
727789 mink : V_ANNOTATION
@@ -732,6 +794,8 @@ def get_gjk_simplex_vertex(solver):
732794 kwargs = {
733795 "obj1" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , 4 )),
734796 "obj2" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , 4 )),
797+ "local_obj1" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , 4 )),
798+ "local_obj2" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , 4 )),
735799 "id1" : V (dtype = gs .ti_int , shape = (_B , 4 )),
736800 "id2" : V (dtype = gs .ti_int , shape = (_B , 4 )),
737801 "mink" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , 4 )),
@@ -756,6 +820,8 @@ def get_epa_polytope_vertex(solver, gjk_static_config):
756820 kwargs = {
757821 "obj1" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_num_polytope_verts )),
758822 "obj2" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_num_polytope_verts )),
823+ "local_obj1" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_num_polytope_verts )),
824+ "local_obj2" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_num_polytope_verts )),
759825 "id1" : V (dtype = gs .ti_int , shape = (_B , max_num_polytope_verts )),
760826 "id2" : V (dtype = gs .ti_int , shape = (_B , max_num_polytope_verts )),
761827 "mink" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , max_num_polytope_verts )),
@@ -865,6 +931,7 @@ class StructEPAPolytopeFace:
865931 normal : V_ANNOTATION
866932 dist2 : V_ANNOTATION
867933 map_idx : V_ANNOTATION
934+ visited : V_ANNOTATION
868935
869936
870937def get_epa_polytope_face (solver , polytope_max_faces ):
@@ -876,6 +943,7 @@ def get_epa_polytope_face(solver, polytope_max_faces):
876943 "normal" : V_VEC (3 , dtype = gs .ti_float , shape = (_B , polytope_max_faces )),
877944 "dist2" : V (dtype = gs .ti_float , shape = (_B , polytope_max_faces )),
878945 "map_idx" : V (dtype = gs .ti_int , shape = (_B , polytope_max_faces )),
946+ "visited" : V (dtype = gs .ti_int , shape = (_B , polytope_max_faces )),
879947 }
880948
881949 if use_ndarray :
@@ -1060,6 +1128,10 @@ class StructGJKState:
10601128 is_col : V_ANNOTATION
10611129 penetration : V_ANNOTATION
10621130 distance : V_ANNOTATION
1131+ # Differentiable contact detection
1132+ diff_contact_input : StructDiffContactInput
1133+ n_diff_contact_input : V_ANNOTATION
1134+ diff_penetration : V_ANNOTATION
10631135
10641136
10651137def get_gjk_state (solver , static_rigid_sim_config , gjk_static_config ):
@@ -1068,6 +1140,7 @@ def get_gjk_state(solver, static_rigid_sim_config, gjk_static_config):
10681140 polytope_max_faces = gjk_static_config .polytope_max_faces
10691141 max_contacts_per_pair = gjk_static_config .max_contacts_per_pair
10701142 max_contact_polygon_verts = gjk_static_config .max_contact_polygon_verts
1143+ requires_grad = solver ._static_rigid_sim_config .requires_grad
10711144
10721145 ### GJK simplex
10731146 simplex_vertex = get_gjk_simplex_vertex (solver )
@@ -1141,6 +1214,9 @@ def get_gjk_state(solver, static_rigid_sim_config, gjk_static_config):
11411214 "is_col" : V (dtype = gs .ti_bool , shape = (_B ,)),
11421215 "penetration" : V (dtype = gs .ti_float , shape = (_B ,)),
11431216 "distance" : V (dtype = gs .ti_float , shape = (_B ,)),
1217+ "diff_contact_input" : get_diff_contact_input (solver , max (1 , max_contacts_per_pair if requires_grad else 1 )),
1218+ "n_diff_contact_input" : V (dtype = gs .ti_int , shape = (_B ,)),
1219+ "diff_penetration" : V (dtype = gs .ti_float , shape = (_B , max_contacts_per_pair )),
11441220 }
11451221 )
11461222
@@ -1799,9 +1875,10 @@ class StructGeomsState:
17991875
18001876def get_geoms_state (solver ):
18011877 shape = solver ._batch_shape (solver .n_geoms_ )
1878+ requires_grad = solver ._static_rigid_sim_config .requires_grad
18021879 kwargs = {
1803- "pos" : V (dtype = gs .ti_vec3 , shape = shape ),
1804- "quat" : V (dtype = gs .ti_vec4 , shape = shape ),
1880+ "pos" : V (dtype = gs .ti_vec3 , shape = shape , needs_grad = requires_grad ),
1881+ "quat" : V (dtype = gs .ti_vec4 , shape = shape , needs_grad = requires_grad ),
18051882 "aabb_min" : V (dtype = gs .ti_vec3 , shape = shape ),
18061883 "aabb_max" : V (dtype = gs .ti_vec3 , shape = shape ),
18071884 "verts_updated" : V (dtype = gs .ti_bool , shape = shape ),
@@ -2381,3 +2458,4 @@ def __init__(self, solver):
23812458LinksStateAdjointCache = ti .template () if not use_ndarray else StructLinksState
23822459JointsStateAdjointCache = ti .template () if not use_ndarray else StructJointsState
23832460GeomsStateAdjointCache = ti .template () if not use_ndarray else StructGeomsState
2461+ DiffContactInput = ti .template () if not use_ndarray else StructDiffContactInput
0 commit comments