Skip to content

Commit 2ed3cd0

Browse files
authored
[BUG FIX] Fix differentiable simulation support for ndarray. (#2068)
* Fix differentiable simulation support for ndarray. * Fix shape issue with zero-copy for matrix field of shape Nx1.
1 parent 4ce5ca8 commit 2ed3cd0

File tree

7 files changed

+125
-230
lines changed

7 files changed

+125
-230
lines changed

genesis/engine/solvers/rigid/collider_decomp.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ def detection(self) -> None:
378378
self._gjk._gjk_static_config,
379379
self._sdf._sdf_info,
380380
self._support_field._support_field_info,
381+
self._gjk._gjk_state.diff_contact_input,
381382
self._solver._errno,
382383
)
383384
func_narrow_phase_convex_specializations(
@@ -439,7 +440,7 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
439440

440441
for key, data in self._contacts_info.items():
441442
if n_envs == 0:
442-
data = data[0, :n_contacts_max]
443+
data = data[0, :n_contacts_max] if not keep_batch_dim else data[:, :n_contacts_max]
443444
if not to_torch:
444445
data = tensor_to_array(data)
445446
else:
@@ -542,6 +543,7 @@ def backward(self, dL_dposition, dL_dnormal, dL_dpenetration):
542543
self._gjk._gjk_state,
543544
self._gjk._gjk_info,
544545
self._gjk._gjk_static_config,
546+
self._collider_state.diff_contact_input,
545547
)
546548

547549

@@ -1570,6 +1572,7 @@ def func_narrow_phase_convex_vs_convex(
15701572
gjk_static_config: ti.template(),
15711573
sdf_info: array_class.SDFInfo,
15721574
support_field_info: array_class.SupportFieldInfo,
1575+
diff_contact_input: array_class.DiffContactInput,
15731576
errno: array_class.V_ANNOTATION,
15741577
):
15751578
"""
@@ -1626,6 +1629,8 @@ def func_narrow_phase_convex_vs_convex(
16261629
gjk_static_config=gjk_static_config,
16271630
sdf_info=sdf_info,
16281631
support_field_info=support_field_info,
1632+
# FIXME: Passing nested data structure as input argument is not supported for now.
1633+
diff_contact_input=diff_contact_input,
16291634
errno=errno,
16301635
)
16311636
else:
@@ -1654,6 +1659,8 @@ def func_narrow_phase_convex_vs_convex(
16541659
gjk_static_config=gjk_static_config,
16551660
sdf_info=sdf_info,
16561661
support_field_info=support_field_info,
1662+
# FIXME: Passing nested data structure as input argument is not supported for now.
1663+
diff_contact_input=diff_contact_input,
16571664
errno=errno,
16581665
)
16591666

@@ -1668,6 +1675,8 @@ def func_narrow_phase_diff_convex_vs_convex(
16681675
gjk_state: array_class.GJKState,
16691676
gjk_info: array_class.GJKInfo,
16701677
gjk_static_config: ti.template(),
1678+
# FIXME: Passing nested data structure as input argument is not supported for now.
1679+
diff_contact_input: array_class.DiffContactInput,
16711680
):
16721681
# Compute reference contacts
16731682
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
@@ -1681,16 +1690,8 @@ def func_narrow_phase_diff_convex_vs_convex(
16811690
if is_ref:
16821691
ref_penetration = -1.0
16831692
contact_pos, contact_normal, penetration, weight = diff_gjk.func_differentiable_contact(
1684-
geoms_state,
1685-
collider_state.diff_contact_input,
1686-
gjk_info,
1687-
i_ga,
1688-
i_gb,
1689-
i_b,
1690-
i_c,
1691-
ref_penetration,
1693+
geoms_state, diff_contact_input, gjk_info, i_ga, i_gb, i_b, i_c, ref_penetration
16921694
)
1693-
16941695
collider_state.diff_contact_input.ref_penetration[i_b, i_c] = penetration
16951696

16961697
func_set_contact(
@@ -1718,15 +1719,9 @@ def func_narrow_phase_diff_convex_vs_convex(
17181719
if not is_ref:
17191720
ref_penetration = collider_state.diff_contact_input.ref_penetration[i_b, ref_id]
17201721
contact_pos, contact_normal, penetration, weight = diff_gjk.func_differentiable_contact(
1721-
geoms_state,
1722-
collider_state.diff_contact_input,
1723-
gjk_info,
1724-
i_ga,
1725-
i_gb,
1726-
i_b,
1727-
i_c,
1728-
ref_penetration,
1722+
geoms_state, diff_contact_input, gjk_info, i_ga, i_gb, i_b, i_c, ref_penetration
17291723
)
1724+
17301725
func_set_contact(
17311726
i_ga,
17321727
i_gb,
@@ -2338,6 +2333,8 @@ def func_convex_convex_contact(
23382333
gjk_static_config: ti.template(),
23392334
sdf_info: array_class.SDFInfo,
23402335
support_field_info: array_class.SupportFieldInfo,
2336+
# FIXME: Passing nested data structure as input argument is not supported for now.
2337+
diff_contact_input: array_class.DiffContactInput,
23412338
errno: array_class.V_ANNOTATION,
23422339
):
23432340
if geoms_info.type[i_ga] == gs.GEOM_TYPE.PLANE and geoms_info.type[i_gb] == gs.GEOM_TYPE.BOX:
@@ -2510,6 +2507,7 @@ def func_convex_convex_contact(
25102507
gjk_state,
25112508
gjk_info,
25122509
support_field_info,
2510+
diff_contact_input,
25132511
i_ga,
25142512
i_gb,
25152513
i_b,

genesis/engine/solvers/rigid/diff_gjk_decomp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def func_gjk_contact(
2121
gjk_state: array_class.GJKState,
2222
gjk_info: array_class.GJKInfo,
2323
support_field_info: array_class.SupportFieldInfo,
24+
# FIXME: Passing nested data structure as input argument is not supported for now.
25+
diff_contact_input: array_class.DiffContactInput,
2426
i_ga,
2527
i_gb,
2628
i_b,
@@ -277,7 +279,7 @@ def func_gjk_contact(
277279
if i_c > 0:
278280
ref_penetration = default_penetration
279281
contact_pos, contact_normal, penetration, weight = func_differentiable_contact(
280-
geoms_state, gjk_state.diff_contact_input, gjk_info, i_ga, i_gb, i_b, i_c, ref_penetration
282+
geoms_state, diff_contact_input, gjk_info, i_ga, i_gb, i_b, i_c, ref_penetration
281283
)
282284
if i_c == 0:
283285
default_penetration = penetration

0 commit comments

Comments
 (0)