@@ -378,6 +378,7 @@ def detection(self) -> None:
378378 self ._gjk ._gjk_static_config ,
379379 self ._sdf ._sdf_info ,
380380 self ._support_field ._support_field_info ,
381+ self ._gjk ._gjk_state .diff_contact_input ,
381382 self ._solver ._errno ,
382383 )
383384 func_narrow_phase_convex_specializations (
@@ -439,7 +440,7 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
439440
440441 for key , data in self ._contacts_info .items ():
441442 if n_envs == 0 :
442- data = data [0 , :n_contacts_max ]
443+ data = data [0 , :n_contacts_max ] if not keep_batch_dim else data [:, : n_contacts_max ]
443444 if not to_torch :
444445 data = tensor_to_array (data )
445446 else :
@@ -542,6 +543,7 @@ def backward(self, dL_dposition, dL_dnormal, dL_dpenetration):
542543 self ._gjk ._gjk_state ,
543544 self ._gjk ._gjk_info ,
544545 self ._gjk ._gjk_static_config ,
546+ self ._collider_state .diff_contact_input ,
545547 )
546548
547549
@@ -1570,6 +1572,7 @@ def func_narrow_phase_convex_vs_convex(
15701572 gjk_static_config : ti .template (),
15711573 sdf_info : array_class .SDFInfo ,
15721574 support_field_info : array_class .SupportFieldInfo ,
1575+ diff_contact_input : array_class .DiffContactInput ,
15731576 errno : array_class .V_ANNOTATION ,
15741577):
15751578 """
@@ -1626,6 +1629,8 @@ def func_narrow_phase_convex_vs_convex(
16261629 gjk_static_config = gjk_static_config ,
16271630 sdf_info = sdf_info ,
16281631 support_field_info = support_field_info ,
1632+ # FIXME: Passing nested data structure as input argument is not supported for now.
1633+ diff_contact_input = diff_contact_input ,
16291634 errno = errno ,
16301635 )
16311636 else :
@@ -1654,6 +1659,8 @@ def func_narrow_phase_convex_vs_convex(
16541659 gjk_static_config = gjk_static_config ,
16551660 sdf_info = sdf_info ,
16561661 support_field_info = support_field_info ,
1662+ # FIXME: Passing nested data structure as input argument is not supported for now.
1663+ diff_contact_input = diff_contact_input ,
16571664 errno = errno ,
16581665 )
16591666
@@ -1668,6 +1675,8 @@ def func_narrow_phase_diff_convex_vs_convex(
16681675 gjk_state : array_class .GJKState ,
16691676 gjk_info : array_class .GJKInfo ,
16701677 gjk_static_config : ti .template (),
1678+ # FIXME: Passing nested data structure as input argument is not supported for now.
1679+ diff_contact_input : array_class .DiffContactInput ,
16711680):
16721681 # Compute reference contacts
16731682 ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .PARTIAL )
@@ -1681,16 +1690,8 @@ def func_narrow_phase_diff_convex_vs_convex(
16811690 if is_ref :
16821691 ref_penetration = - 1.0
16831692 contact_pos , contact_normal , penetration , weight = diff_gjk .func_differentiable_contact (
1684- geoms_state ,
1685- collider_state .diff_contact_input ,
1686- gjk_info ,
1687- i_ga ,
1688- i_gb ,
1689- i_b ,
1690- i_c ,
1691- ref_penetration ,
1693+ geoms_state , diff_contact_input , gjk_info , i_ga , i_gb , i_b , i_c , ref_penetration
16921694 )
1693-
16941695 collider_state .diff_contact_input .ref_penetration [i_b , i_c ] = penetration
16951696
16961697 func_set_contact (
@@ -1718,15 +1719,9 @@ def func_narrow_phase_diff_convex_vs_convex(
17181719 if not is_ref :
17191720 ref_penetration = collider_state .diff_contact_input .ref_penetration [i_b , ref_id ]
17201721 contact_pos , contact_normal , penetration , weight = diff_gjk .func_differentiable_contact (
1721- geoms_state ,
1722- collider_state .diff_contact_input ,
1723- gjk_info ,
1724- i_ga ,
1725- i_gb ,
1726- i_b ,
1727- i_c ,
1728- ref_penetration ,
1722+ geoms_state , diff_contact_input , gjk_info , i_ga , i_gb , i_b , i_c , ref_penetration
17291723 )
1724+
17301725 func_set_contact (
17311726 i_ga ,
17321727 i_gb ,
@@ -2338,6 +2333,8 @@ def func_convex_convex_contact(
23382333 gjk_static_config : ti .template (),
23392334 sdf_info : array_class .SDFInfo ,
23402335 support_field_info : array_class .SupportFieldInfo ,
2336+ # FIXME: Passing nested data structure as input argument is not supported for now.
2337+ diff_contact_input : array_class .DiffContactInput ,
23412338 errno : array_class .V_ANNOTATION ,
23422339):
23432340 if geoms_info .type [i_ga ] == gs .GEOM_TYPE .PLANE and geoms_info .type [i_gb ] == gs .GEOM_TYPE .BOX :
@@ -2510,6 +2507,7 @@ def func_convex_convex_contact(
25102507 gjk_state ,
25112508 gjk_info ,
25122509 support_field_info ,
2510+ diff_contact_input ,
25132511 i_ga ,
25142512 i_gb ,
25152513 i_b ,
0 commit comments