Skip to content

Commit 62dcf0f

Browse files
committed
Merge branch 'main' into easydiffrigid
2 parents b04a29a + 1794b85 commit 62dcf0f

File tree

11 files changed

+1527
-91
lines changed

11 files changed

+1527
-91
lines changed

genesis/engine/scene.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,22 @@ def _validate_options(
273273
if not isinstance(renderer_options, RendererOptions):
274274
gs.raise_exception("`renderer` should be an instance of `gs.renderers.Renderer`.")
275275

276+
# Validate rigid_options against sim_options
277+
if rigid_options.box_box_detection is None:
278+
rigid_options.box_box_detection = not sim_options.requires_grad
279+
elif rigid_options.box_box_detection and sim_options.requires_grad:
280+
gs.raise_exception(
281+
"`rigid_options.box_box_detection` cannot be True when `sim_options.requires_grad` is True."
282+
)
283+
if not rigid_options.use_gjk_collision and sim_options.requires_grad:
284+
gs.raise_exception(
285+
"`rigid_options.use_gjk_collision` cannot be False when `sim_options.requires_grad` is True."
286+
)
287+
if rigid_options.enable_mujoco_compatibility and sim_options.requires_grad:
288+
gs.raise_exception(
289+
"`rigid_options.enable_mujoco_compatibility` cannot be True when `sim_options.requires_grad` is True."
290+
)
291+
276292
def destroy(self):
277293
if getattr(self, "_recorder_manager", None) is not None:
278294
if self._recorder_manager.is_recording:

genesis/engine/solvers/pbd_solver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -923,6 +923,7 @@ def set_animate_particles_by_link(
923923
envs_idx = self._scene._sanitize_envs_idx(envs_idx)
924924
self._sim._coupler.kernel_attach_pbd_to_rigid_link(particles_idx, envs_idx, link_idx, links_state)
925925

926+
@ti.kernel
926927
def _kernel_get_particles_vel(
927928
self,
928929
particle_start: ti.i32,

genesis/engine/solvers/rigid/collider_decomp.py

Lines changed: 294 additions & 51 deletions
Large diffs are not rendered by default.

genesis/engine/solvers/rigid/diff_gjk_decomp.py

Lines changed: 897 additions & 0 deletions
Large diffs are not rendered by default.

genesis/engine/solvers/rigid/gjk_decomp.py

Lines changed: 89 additions & 21 deletions
Large diffs are not rendered by default.

genesis/engine/solvers/rigid/mpr_decomp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,18 +196,18 @@ def support_driver(
196196
v = ti.Vector.zero(gs.ti_float, 3)
197197
geom_type = geoms_info.type[i_g]
198198
if geom_type == gs.GEOM_TYPE.SPHERE:
199-
v = support_field._func_support_sphere(geoms_state, geoms_info, direction, i_g, i_b, False)
199+
v, v_, vid = support_field._func_support_sphere(geoms_state, geoms_info, direction, i_g, i_b, False)
200200
elif geom_type == gs.GEOM_TYPE.ELLIPSOID:
201201
v = support_field._func_support_ellipsoid(geoms_state, geoms_info, direction, i_g, i_b)
202202
elif geom_type == gs.GEOM_TYPE.CAPSULE:
203203
v = support_field._func_support_capsule(geoms_state, geoms_info, direction, i_g, i_b, False)
204204
elif geom_type == gs.GEOM_TYPE.BOX:
205-
v, _ = support_field._func_support_box(geoms_state, geoms_info, direction, i_g, i_b)
205+
v, v_, vid = support_field._func_support_box(geoms_state, geoms_info, direction, i_g, i_b)
206206
elif geom_type == gs.GEOM_TYPE.TERRAIN:
207207
if ti.static(collider_static_config.has_terrain):
208208
v, _ = support_field._func_support_prism(collider_state, direction, i_g, i_b)
209209
else:
210-
v, _ = support_field._func_support_world(
210+
v, v_, vid = support_field._func_support_world(
211211
geoms_state, geoms_info, support_field_info, support_field_static_config, direction, i_g, i_b
212212
)
213213
return v

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def build(self):
247247
self._static_rigid_sim_cache_key = array_class.get_static_rigid_sim_cache_key(self)
248248
self._static_rigid_sim_config = self.StaticRigidSimConfig(
249249
para_level=self.sim._para_level,
250+
requires_grad=getattr(self.sim.options, "requires_grad", False),
250251
use_hibernation=getattr(self, "_use_hibernation", False),
251252
use_contact_island=getattr(self, "_use_contact_island", False),
252253
batch_links_info=getattr(self._options, "batch_links_info", False),

genesis/engine/solvers/rigid/support_field_decomp.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,9 @@ def _func_support_world(
139139
g_pos = geoms_state.pos[i_g, i_b]
140140
g_quat = geoms_state.quat[i_g, i_b]
141141
d_mesh = gu.ti_transform_by_quat(d, gu.ti_inv_quat(g_quat))
142-
v, vid = _func_support_mesh(support_field_info, support_field_static_config, d_mesh, i_g)
143-
v_ = gu.ti_transform_by_trans_quat(v, g_pos, g_quat)
144-
return v_, vid
142+
v_, vid = _func_support_mesh(support_field_info, support_field_static_config, d_mesh, i_g)
143+
v = gu.ti_transform_by_trans_quat(v_, g_pos, g_quat)
144+
return v, v_, vid
145145

146146

147147
@ti.func
@@ -207,10 +207,18 @@ def _func_support_sphere(
207207
sphere_radius = geoms_info.data[i_g][0]
208208

209209
# Shrink the sphere to a point
210-
res = sphere_center
210+
v = sphere_center
211+
v_ = ti.Vector.zero(gs.ti_float, 3)
212+
vid = -1
211213
if not shrink:
212-
res += d * sphere_radius
213-
return res
214+
v += d * sphere_radius
215+
216+
# Local position of the support point
217+
g_quat = geoms_state.quat[i_g, i_b]
218+
local_d = gu.ti_inv_transform_by_quat(d, g_quat)
219+
v_ = local_d * sphere_radius
220+
221+
return v, v_, vid
214222

215223

216224
@ti.func
@@ -308,7 +316,7 @@ def _func_support_box(
308316
vid = (v_[0] > 0.0) * 1 + (v_[1] > 0.0) * 2 + (v_[2] > 0.0) * 4
309317
vid += geoms_info.vert_start[i_g]
310318
v = gu.ti_transform_by_trans_quat(v_, g_pos, g_quat)
311-
return v, vid
319+
return v, v_, vid
312320

313321

314322
@ti.func

genesis/options/solvers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ class RigidOptions(Options):
315315
contact_resolve_time: Optional[float] = None
316316
constraint_timeconst: float = 0.01
317317
use_contact_island: bool = False
318-
box_box_detection: bool = True
318+
box_box_detection: Optional[bool] = None
319319

320320
# hibernation threshold
321321
use_hibernation: bool = False

genesis/utils/array_class.py

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,15 @@ class StructContactData:
360360
link_b: V_ANNOTATION
361361

362362

363-
def get_contact_data(solver, max_contact_pairs):
363+
def get_contact_data(solver, max_contact_pairs, requires_grad):
364364
f_batch = solver._batch_shape
365365
max_contact_pairs_ = max(1, max_contact_pairs)
366366
kwargs = {
367367
"geom_a": V(dtype=gs.ti_int, shape=f_batch(max_contact_pairs_)),
368368
"geom_b": V(dtype=gs.ti_int, shape=f_batch(max_contact_pairs_)),
369-
"penetration": V(dtype=gs.ti_float, shape=f_batch(max_contact_pairs_)),
370-
"normal": V_VEC(3, dtype=gs.ti_float, shape=f_batch(max_contact_pairs_)),
371-
"pos": V_VEC(3, dtype=gs.ti_float, shape=f_batch(max_contact_pairs_)),
369+
"normal": V(dtype=gs.ti_vec3, shape=f_batch(max_contact_pairs_), needs_grad=requires_grad),
370+
"pos": V(dtype=gs.ti_vec3, shape=f_batch(max_contact_pairs_), needs_grad=requires_grad),
371+
"penetration": V(dtype=gs.ti_float, shape=f_batch(max_contact_pairs_), needs_grad=requires_grad),
372372
"friction": V(dtype=gs.ti_float, shape=f_batch(max_contact_pairs_)),
373373
"sol_params": V_VEC(7, dtype=gs.ti_float, shape=f_batch(max_contact_pairs_)),
374374
"force": V(dtype=gs.ti_vec3, shape=f_batch(max_contact_pairs_)),
@@ -389,6 +389,62 @@ def __init__(self):
389389
return ClassContactData()
390390

391391

392+
@dataclasses.dataclass
393+
class StructDiffContactInput:
394+
### Non-differentiable input data
395+
# Geom id of the two geometries
396+
geom_a: V_ANNOTATION
397+
geom_b: V_ANNOTATION
398+
# Local positions of the 3 vertices from the two geometries that define the face on the Minkowski difference
399+
local_pos1_a: V_ANNOTATION
400+
local_pos1_b: V_ANNOTATION
401+
local_pos1_c: V_ANNOTATION
402+
local_pos2_a: V_ANNOTATION
403+
local_pos2_b: V_ANNOTATION
404+
local_pos2_c: V_ANNOTATION
405+
# Local positions of the 1 vertex from the two geometries that define the support point for the face above
406+
w_local_pos1: V_ANNOTATION
407+
w_local_pos2: V_ANNOTATION
408+
# Reference id of the contact point, which is needed for the backward pass
409+
ref_id: V_ANNOTATION
410+
# Flag whether the contact data can be computed in numerically stable way in both the forward and backward passes
411+
valid: V_ANNOTATION
412+
### Differentiable input data
413+
# Reference penetration depth, which is needed for computing the weight of the contact point
414+
ref_penetration: V_ANNOTATION
415+
416+
417+
def get_diff_contact_input(solver, max_contacts_per_pair):
418+
_B = solver._B
419+
kwargs = {
420+
"geom_a": V(dtype=gs.ti_int, shape=(_B, max_contacts_per_pair)),
421+
"geom_b": V(dtype=gs.ti_int, shape=(_B, max_contacts_per_pair)),
422+
"local_pos1_a": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_contacts_per_pair)),
423+
"local_pos1_b": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_contacts_per_pair)),
424+
"local_pos1_c": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_contacts_per_pair)),
425+
"local_pos2_a": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_contacts_per_pair)),
426+
"local_pos2_b": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_contacts_per_pair)),
427+
"local_pos2_c": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_contacts_per_pair)),
428+
"w_local_pos1": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_contacts_per_pair)),
429+
"w_local_pos2": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_contacts_per_pair)),
430+
"ref_id": V(dtype=gs.ti_int, shape=(_B, max_contacts_per_pair)),
431+
"valid": V(dtype=gs.ti_int, shape=(_B, max_contacts_per_pair)),
432+
"ref_penetration": V(dtype=gs.ti_float, shape=(_B, max_contacts_per_pair), needs_grad=True),
433+
}
434+
435+
if use_ndarray:
436+
return StructDiffContactInput(**kwargs)
437+
else:
438+
439+
@ti.data_oriented
440+
class ClassDiffContactInput:
441+
def __init__(self):
442+
for k, v in kwargs.items():
443+
setattr(self, k, v)
444+
445+
return ClassDiffContactInput()
446+
447+
392448
@dataclasses.dataclass
393449
class StructSortBuffer:
394450
value: V_ANNOTATION
@@ -547,19 +603,22 @@ class StructColliderState:
547603
n_contacts_hibernated: V_ANNOTATION
548604
first_time: V_ANNOTATION
549605
contact_cache: StructContactCache
606+
# Input data for differentiable contact detection used in the backward pass
607+
diff_contact_input: StructDiffContactInput
550608

551609

552-
def get_collider_state(solver, n_possible_pairs, collider_static_config):
610+
def get_collider_state(solver, static_rigid_sim_config, n_possible_pairs, collider_static_config):
553611
_B = solver._B
554612
f_batch = solver._batch_shape
555613
n_geoms = solver.n_geoms_
556614
max_collision_pairs = min(solver._max_collision_pairs, n_possible_pairs)
557615
max_collision_pairs_broad = max_collision_pairs * collider_static_config.max_collision_pairs_broad_k
558616
max_contact_pairs = max_collision_pairs * collider_static_config.n_contacts_per_pair
617+
requires_grad = static_rigid_sim_config.requires_grad
559618

560619
############## broad phase SAP ##############
561620

562-
contact_data = get_contact_data(solver, max_contact_pairs)
621+
contact_data = get_contact_data(solver, max_contact_pairs, requires_grad)
563622
sort_buffer = get_sort_buffer(solver)
564623
contact_cache = get_contact_cache(solver)
565624
kwargs = {
@@ -584,6 +643,7 @@ def get_collider_state(solver, n_possible_pairs, collider_static_config):
584643
"n_contacts_hibernated": V(dtype=gs.ti_int, shape=_B),
585644
"first_time": V(dtype=gs.ti_int, shape=_B),
586645
"contact_cache": contact_cache,
646+
"diff_contact_input": get_diff_contact_input(solver, max_contact_pairs if requires_grad else 1),
587647
}
588648

589649
if use_ndarray:
@@ -722,6 +782,8 @@ class StructMDVertex:
722782
# Vertex of the Minkowski difference
723783
obj1: V_ANNOTATION
724784
obj2: V_ANNOTATION
785+
local_obj1: V_ANNOTATION
786+
local_obj2: V_ANNOTATION
725787
id1: V_ANNOTATION
726788
id2: V_ANNOTATION
727789
mink: V_ANNOTATION
@@ -732,6 +794,8 @@ def get_gjk_simplex_vertex(solver):
732794
kwargs = {
733795
"obj1": V_VEC(3, dtype=gs.ti_float, shape=(_B, 4)),
734796
"obj2": V_VEC(3, dtype=gs.ti_float, shape=(_B, 4)),
797+
"local_obj1": V_VEC(3, dtype=gs.ti_float, shape=(_B, 4)),
798+
"local_obj2": V_VEC(3, dtype=gs.ti_float, shape=(_B, 4)),
735799
"id1": V(dtype=gs.ti_int, shape=(_B, 4)),
736800
"id2": V(dtype=gs.ti_int, shape=(_B, 4)),
737801
"mink": V_VEC(3, dtype=gs.ti_float, shape=(_B, 4)),
@@ -756,6 +820,8 @@ def get_epa_polytope_vertex(solver, gjk_static_config):
756820
kwargs = {
757821
"obj1": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_num_polytope_verts)),
758822
"obj2": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_num_polytope_verts)),
823+
"local_obj1": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_num_polytope_verts)),
824+
"local_obj2": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_num_polytope_verts)),
759825
"id1": V(dtype=gs.ti_int, shape=(_B, max_num_polytope_verts)),
760826
"id2": V(dtype=gs.ti_int, shape=(_B, max_num_polytope_verts)),
761827
"mink": V_VEC(3, dtype=gs.ti_float, shape=(_B, max_num_polytope_verts)),
@@ -865,6 +931,7 @@ class StructEPAPolytopeFace:
865931
normal: V_ANNOTATION
866932
dist2: V_ANNOTATION
867933
map_idx: V_ANNOTATION
934+
visited: V_ANNOTATION
868935

869936

870937
def get_epa_polytope_face(solver, polytope_max_faces):
@@ -876,6 +943,7 @@ def get_epa_polytope_face(solver, polytope_max_faces):
876943
"normal": V_VEC(3, dtype=gs.ti_float, shape=(_B, polytope_max_faces)),
877944
"dist2": V(dtype=gs.ti_float, shape=(_B, polytope_max_faces)),
878945
"map_idx": V(dtype=gs.ti_int, shape=(_B, polytope_max_faces)),
946+
"visited": V(dtype=gs.ti_int, shape=(_B, polytope_max_faces)),
879947
}
880948

881949
if use_ndarray:
@@ -1060,6 +1128,10 @@ class StructGJKState:
10601128
is_col: V_ANNOTATION
10611129
penetration: V_ANNOTATION
10621130
distance: V_ANNOTATION
1131+
# Differentiable contact detection
1132+
diff_contact_input: StructDiffContactInput
1133+
n_diff_contact_input: V_ANNOTATION
1134+
diff_penetration: V_ANNOTATION
10631135

10641136

10651137
def get_gjk_state(solver, static_rigid_sim_config, gjk_static_config):
@@ -1068,6 +1140,7 @@ def get_gjk_state(solver, static_rigid_sim_config, gjk_static_config):
10681140
polytope_max_faces = gjk_static_config.polytope_max_faces
10691141
max_contacts_per_pair = gjk_static_config.max_contacts_per_pair
10701142
max_contact_polygon_verts = gjk_static_config.max_contact_polygon_verts
1143+
requires_grad = solver._static_rigid_sim_config.requires_grad
10711144

10721145
### GJK simplex
10731146
simplex_vertex = get_gjk_simplex_vertex(solver)
@@ -1141,6 +1214,9 @@ def get_gjk_state(solver, static_rigid_sim_config, gjk_static_config):
11411214
"is_col": V(dtype=gs.ti_bool, shape=(_B,)),
11421215
"penetration": V(dtype=gs.ti_float, shape=(_B,)),
11431216
"distance": V(dtype=gs.ti_float, shape=(_B,)),
1217+
"diff_contact_input": get_diff_contact_input(solver, max(1, max_contacts_per_pair if requires_grad else 1)),
1218+
"n_diff_contact_input": V(dtype=gs.ti_int, shape=(_B,)),
1219+
"diff_penetration": V(dtype=gs.ti_float, shape=(_B, max_contacts_per_pair)),
11441220
}
11451221
)
11461222

@@ -1799,9 +1875,10 @@ class StructGeomsState:
17991875

18001876
def get_geoms_state(solver):
18011877
shape = solver._batch_shape(solver.n_geoms_)
1878+
requires_grad = solver._static_rigid_sim_config.requires_grad
18021879
kwargs = {
1803-
"pos": V(dtype=gs.ti_vec3, shape=shape),
1804-
"quat": V(dtype=gs.ti_vec4, shape=shape),
1880+
"pos": V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad),
1881+
"quat": V(dtype=gs.ti_vec4, shape=shape, needs_grad=requires_grad),
18051882
"aabb_min": V(dtype=gs.ti_vec3, shape=shape),
18061883
"aabb_max": V(dtype=gs.ti_vec3, shape=shape),
18071884
"verts_updated": V(dtype=gs.ti_bool, shape=shape),
@@ -2381,3 +2458,4 @@ def __init__(self, solver):
23812458
LinksStateAdjointCache = ti.template() if not use_ndarray else StructLinksState
23822459
JointsStateAdjointCache = ti.template() if not use_ndarray else StructJointsState
23832460
GeomsStateAdjointCache = ti.template() if not use_ndarray else StructGeomsState
2461+
DiffContactInput = ti.template() if not use_ndarray else StructDiffContactInput

0 commit comments

Comments
 (0)