Skip to content

Commit 6ba836c

Browse files
authored
[MISC] Reduce contact cache memory usage. (#2031)
* Reduce contact cache memory usage. * Avoid calling collision detection if no valid collision pairs.
1 parent 2b64f0e commit 6ba836c

File tree

4 files changed

+65
-61
lines changed

4 files changed

+65
-61
lines changed

genesis/engine/bvh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,4 +569,4 @@ def filter(self, i_a, i_q):
569569
"""
570570
i_ag = self.coupler.rigid_volume_elems_geom_idx[i_a]
571571
i_qg = self.coupler.rigid_volume_elems_geom_idx[i_q]
572-
return not self.coupler.rigid_collision_pair_validity[i_ag, i_qg]
572+
return self.coupler.rigid_collision_pair_idx[i_ag, i_qg] == -1

genesis/engine/couplers/sap_coupler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,10 +357,10 @@ def _init_hydroelastic_rigid_fields_and_info(self):
357357
self.rigid_volume_verts_geom_idx.from_numpy(rigid_volume_verts_geom_idx_np)
358358
self.rigid_volume_elems_geom_idx = ti.field(gs.ti_int, shape=(self.n_rigid_volume_elems,))
359359
self.rigid_volume_elems_geom_idx.from_numpy(rigid_volume_elems_geom_idx_np)
360-
# FIXME: Convert collision_pair_validity to field here because SAPCouler cannot support ndarray/field switch yet
361-
np_collision_pair_validity = self.rigid_solver.collider._collider_info.collision_pair_validity.to_numpy()
362-
self.rigid_collision_pair_validity = ti.field(gs.ti_int, shape=np_collision_pair_validity.shape)
363-
self.rigid_collision_pair_validity.from_numpy(np_collision_pair_validity)
360+
# FIXME: Convert collision_pair_idx to field here because SAPCoupler cannot support ndarray/field switch yet
361+
np_collision_pair_idx = self.rigid_solver.collider._collider_info.collision_pair_idx.to_numpy()
362+
self.rigid_collision_pair_idx = ti.field(gs.ti_int, shape=np_collision_pair_idx.shape)
363+
self.rigid_collision_pair_idx.from_numpy(np_collision_pair_idx)
364364
self.rigid_pressure_field = ti.field(gs.ti_float, shape=(self.n_rigid_volume_verts,))
365365
self.rigid_pressure_field.from_numpy(rigid_pressure_field_np)
366366
self.rigid_pressure_gradient_rest = ti.field(gs.ti_vec3, shape=(self.n_rigid_volume_elems,))

genesis/engine/solvers/rigid/collider_decomp.py

Lines changed: 54 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _init_static_config(self) -> None:
115115

116116
def _init_collision_fields(self) -> None:
117117
# Pre-compute fields, as they are needed to initialize the collider state and info.
118-
n_possible_pairs_, collision_pair_validity = self._compute_collision_pair_validity()
118+
self._n_possible_pairs, collision_pair_idx = self._compute_collision_pair_idx()
119119
vert_neighbors, vert_neighbor_start, vert_n_neighbors = self._compute_verts_connectivity()
120120
n_vert_neighbors = len(vert_neighbors)
121121

@@ -131,17 +131,17 @@ def _init_collision_fields(self) -> None:
131131
diff_pos_tolerance=self._diff_pos_tolerance,
132132
diff_normal_tolerance=self._diff_normal_tolerance,
133133
)
134-
self._init_collision_pair_validity(collision_pair_validity)
134+
self._init_collision_pair_idx(collision_pair_idx)
135135
self._init_verts_connectivity(vert_neighbors, vert_neighbor_start, vert_n_neighbors)
136-
self._init_max_contact_pairs(n_possible_pairs_)
136+
self._init_max_contact_pairs(self._n_possible_pairs)
137137
self._init_terrain_state()
138138

139139
# Initialize [state], which stores every data that are may be updated at every single simulation step
140-
n_possible_pairs = max(n_possible_pairs_, 1)
140+
n_possible_pairs_ = max(self._n_possible_pairs, 1)
141141
self._collider_state = array_class.get_collider_state(
142142
self._solver,
143143
self._solver._static_rigid_sim_config,
144-
n_possible_pairs,
144+
n_possible_pairs_,
145145
self._solver._options.multiplier_collision_broad_phase,
146146
self._collider_info,
147147
self._collider_static_config,
@@ -152,9 +152,9 @@ def _init_collision_fields(self) -> None:
152152

153153
self.reset()
154154

155-
def _compute_collision_pair_validity(self):
155+
def _compute_collision_pair_idx(self):
156156
"""
157-
Compute the collision pair validity matrix.
157+
Compute flat indices of all valid collision pairs.
158158
159159
For each pair of geoms, determine if they can collide based on their properties and the solver configuration.
160160
"""
@@ -183,7 +183,7 @@ def _compute_collision_pair_validity(self):
183183
entities_is_local_collision_mask = solver.entities_info.is_local_collision_mask.to_numpy()
184184

185185
n_possible_pairs = 0
186-
collision_pair_validity = np.zeros((n_geoms, n_geoms), dtype=gs.np_int)
186+
collision_pair_idx = np.full((n_geoms, n_geoms), fill_value=-1, dtype=gs.np_int)
187187
for i_ga in range(n_geoms):
188188
for i_gb in range(i_ga + 1, n_geoms):
189189
i_la = geoms_link_idx[i_ga]
@@ -229,10 +229,10 @@ def _compute_collision_pair_validity(self):
229229
if links_is_fixed[i_la] and links_is_fixed[i_lb]:
230230
continue
231231

232-
collision_pair_validity[i_ga, i_gb] = 1
232+
collision_pair_idx[i_ga, i_gb] = n_possible_pairs
233233
n_possible_pairs += 1
234234

235-
return n_possible_pairs, collision_pair_validity
235+
return n_possible_pairs, collision_pair_idx
236236

237237
def _compute_verts_connectivity(self):
238238
"""
@@ -255,8 +255,8 @@ def _compute_verts_connectivity(self):
255255

256256
return vert_neighbors, vert_neighbor_start, vert_n_neighbors
257257

258-
def _init_collision_pair_validity(self, collision_pair_validity):
259-
self._collider_info.collision_pair_validity.from_numpy(collision_pair_validity)
258+
def _init_collision_pair_idx(self, collision_pair_idx):
259+
self._collider_info.collision_pair_idx.from_numpy(collision_pair_idx)
260260

261261
def _init_verts_connectivity(self, vert_neighbors, vert_neighbor_start, vert_n_neighbors):
262262
if self._solver.n_verts > 0:
@@ -299,22 +299,23 @@ def _init_terrain_state(self):
299299
def reset(self, envs_idx: npt.NDArray[np.int32] | None = None, cache_only: bool = False) -> None:
300300
self._contacts_info_cache.clear()
301301
if gs.use_zerocopy:
302-
mask = () if envs_idx is None else envs_idx
302+
envs_idx = slice(None) if envs_idx is None else envs_idx
303303
if not cache_only:
304304
first_time = ti_to_torch(self._collider_state.first_time, copy=False)
305305
if isinstance(envs_idx, torch.Tensor):
306306
first_time.scatter_(0, envs_idx, True)
307307
else:
308-
first_time[mask] = True
308+
first_time[envs_idx] = True
309+
309310
i_va_ws = ti_to_torch(self._collider_state.contact_cache.i_va_ws, copy=False)
310311
normal = ti_to_torch(self._collider_state.contact_cache.normal, copy=False)
311312
if isinstance(envs_idx, torch.Tensor):
312-
n_geoms = i_va_ws.shape[0]
313-
i_va_ws.scatter_(2, envs_idx[None, None].expand((n_geoms, n_geoms, -1)), -1)
314-
normal.scatter_(2, envs_idx[None, None, :, None].expand((n_geoms, n_geoms, -1, 3)), 0.0)
313+
max_possible_pairs = normal.shape[0]
314+
i_va_ws.scatter_(2, envs_idx[None, None].expand((2, max_possible_pairs, -1)), -1)
315+
normal.scatter_(1, envs_idx[None, :, None].expand((max_possible_pairs, -1, 3)), 0.0)
315316
else:
316-
i_va_ws[mask] = -1
317-
normal[mask] = 0.0
317+
i_va_ws[:, :, envs_idx] = -1
318+
normal[:, envs_idx] = 0.0
318319
return
319320

320321
if envs_idx is None:
@@ -333,12 +334,16 @@ def clear(self, envs_idx=None):
333334
)
334335

335336
def detection(self) -> None:
336-
self._contacts_info_cache.clear()
337337
rigid_solver.kernel_update_geom_aabbs(
338338
self._solver.geoms_state,
339339
self._solver.geoms_init_AABB,
340340
self._solver._static_rigid_sim_config,
341341
)
342+
343+
if self._n_possible_pairs == 0:
344+
return
345+
346+
self._contacts_info_cache.clear()
342347
func_broad_phase(
343348
self._solver.links_state,
344349
self._solver.links_info,
@@ -565,7 +570,7 @@ def collider_kernel_reset(
565570
collider_state: array_class.ColliderState,
566571
cache_only: ti.template(),
567572
):
568-
n_geoms = collider_state.active_buffer.shape[0]
573+
max_possible_pairs = collider_state.contact_cache.normal.shape[0]
569574

570575
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
571576
for i_b_ in range(envs_idx.shape[0]):
@@ -574,10 +579,10 @@ def collider_kernel_reset(
574579
if ti.static(not cache_only):
575580
collider_state.first_time[i_b] = True
576581

577-
for i_ga, i_gb in ti.ndrange(n_geoms, n_geoms):
578-
collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b] = -1
579-
collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b] = -1
580-
collider_state.contact_cache.normal[i_ga, i_gb, i_b] = ti.Vector.zero(gs.ti_float, 3)
582+
for i_pair in range(max_possible_pairs):
583+
collider_state.contact_cache.i_va_ws[0, i_pair, i_b] = -1
584+
collider_state.contact_cache.i_va_ws[1, i_pair, i_b] = -1
585+
collider_state.contact_cache.normal[i_pair, i_b] = ti.Vector.zero(gs.ti_float, 3)
581586

582587

583588
# only used with hibernation ??
@@ -1214,7 +1219,7 @@ def func_check_collision_valid(
12141219
equalities_info: array_class.EqualitiesInfo,
12151220
collider_info: array_class.ColliderInfo,
12161221
):
1217-
is_valid = collider_info.collision_pair_validity[i_ga, i_gb]
1222+
is_valid = collider_info.collision_pair_idx[i_ga, i_gb] != -1
12181223

12191224
if is_valid:
12201225
i_la = geoms_info.link_idx[i_ga]
@@ -1408,9 +1413,10 @@ def func_broad_phase(
14081413
if not func_is_geom_aabbs_overlap(i_ga, i_gb, i_b, geoms_state, geoms_info):
14091414
# Clear collision normal cache if not in contact
14101415
if ti.static(not static_rigid_sim_config.enable_mujoco_compatibility):
1411-
collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b] = -1
1412-
collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b] = -1
1413-
collider_state.contact_cache.normal[i_ga, i_gb, i_b] = ti.Vector.zero(gs.ti_float, 3)
1416+
i_pair = collider_info.collision_pair_idx[i_ga, i_gb]
1417+
collider_state.contact_cache.i_va_ws[0, i_pair, i_b] = -1
1418+
collider_state.contact_cache.i_va_ws[1, i_pair, i_b] = -1
1419+
collider_state.contact_cache.normal[i_pair, i_b] = ti.Vector.zero(gs.ti_float, 3)
14141420
continue
14151421

14161422
i_p = collider_state.n_broad_pairs[i_b]
@@ -1465,11 +1471,10 @@ def func_broad_phase(
14651471
if not func_is_geom_aabbs_overlap(i_ga, i_gb, i_b, geoms_state, geoms_info):
14661472
# Clear collision normal cache if not in contact
14671473
if ti.static(not static_rigid_sim_config.enable_mujoco_compatibility):
1468-
collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b] = -1
1469-
collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b] = -1
1470-
collider_state.contact_cache.normal[i_ga, i_gb, i_b] = ti.Vector.zero(
1471-
gs.ti_float, 3
1472-
)
1474+
i_pair = collider_info.collision_pair_idx[i_ga, i_gb]
1475+
collider_state.contact_cache.i_va_ws[0, i_pair, i_b] = -1
1476+
collider_state.contact_cache.i_va_ws[1, i_pair, i_b] = -1
1477+
collider_state.contact_cache.normal[i_pair, i_b] = ti.Vector.zero(gs.ti_float, 3)
14731478
continue
14741479

14751480
collider_state.broad_collision_pairs[collider_state.n_broad_pairs[i_b], i_b][0] = i_ga
@@ -1501,11 +1506,10 @@ def func_broad_phase(
15011506

15021507
if not func_is_geom_aabbs_overlap(i_ga, i_gb, i_b, geoms_state, geoms_info):
15031508
# Clear collision normal cache if not in contact
1504-
collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b] = -1
1505-
collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b] = -1
1506-
collider_state.contact_cache.normal[i_ga, i_gb, i_b] = ti.Vector.zero(
1507-
gs.ti_float, 3
1508-
)
1509+
i_pair = collider_info.collision_pair_idx[i_ga, i_gb]
1510+
collider_state.contact_cache.i_va_ws[0, i_pair, i_b] = -1
1511+
collider_state.contact_cache.i_va_ws[1, i_pair, i_b] = -1
1512+
collider_state.contact_cache.normal[i_pair, i_b] = ti.Vector.zero(gs.ti_float, 3)
15091513
continue
15101514

15111515
collider_state.broad_collision_pairs[collider_state.n_broad_pairs[i_b], i_b][0] = i_ga
@@ -2395,6 +2399,7 @@ def func_convex_convex_contact(
23952399
axis_1 = ti.Vector.zero(gs.ti_float, 3)
23962400
qrot = ti.Vector.zero(gs.ti_float, 4)
23972401

2402+
i_pair = collider_info.collision_pair_idx[(i_gb, i_ga) if i_ga > i_gb else (i_ga, i_gb)]
23982403
for i_detection in range(5):
23992404
try_sdf = False
24002405
prefer_sdf = False
@@ -2410,8 +2415,7 @@ def func_convex_convex_contact(
24102415
if (multi_contact and is_col_0) or (i_detection == 0):
24112416
if geoms_info.type[i_ga] == gs.GEOM_TYPE.PLANE:
24122417
plane_dir = ti.Vector(
2413-
[geoms_info.data[i_ga][0], geoms_info.data[i_ga][1], geoms_info.data[i_ga][2]],
2414-
dt=gs.ti_float,
2418+
[geoms_info.data[i_ga][0], geoms_info.data[i_ga][1], geoms_info.data[i_ga][2]], dt=gs.ti_float
24152419
)
24162420
plane_dir = gu.ti_transform_by_quat(plane_dir, geoms_state.quat[i_ga, i_b])
24172421
normal = -plane_dir.normalized()
@@ -2438,7 +2442,7 @@ def func_convex_convex_contact(
24382442
# Try using MPR before anything else
24392443
is_mpr_updated = False
24402444
is_mpr_guess_direction_available = True
2441-
normal_ws = collider_state.contact_cache.normal[i_ga, i_gb, i_b]
2445+
normal_ws = collider_state.contact_cache.normal[i_pair, i_b]
24422446
for i_mpr in range(2):
24432447
if i_mpr == 1:
24442448
# Try without warm-start if no contact was detected using it.
@@ -2604,8 +2608,8 @@ def func_convex_convex_contact(
26042608
penetration_a = gs.ti_float(0.0)
26052609
contact_pos_a = ti.Vector.zero(gs.ti_float, 3)
26062610
contact_pos_b = ti.Vector.zero(gs.ti_float, 3)
2607-
i_va = collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b]
2608-
i_vb = collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b]
2611+
i_va = collider_state.contact_cache.i_va_ws[0, i_pair, i_b]
2612+
i_vb = collider_state.contact_cache.i_va_ws[1, i_pair, i_b]
26092613
for i_sdf in range(2):
26102614
is_col_i, normal_i, penetration_i, contact_pos_i, i_vi = func_contact_convex_convex_sdf(
26112615
i_ga if i_sdf == 0 else i_gb,
@@ -2654,14 +2658,14 @@ def func_convex_convex_contact(
26542658
normal = normal_a
26552659
penetration = penetration_a
26562660
contact_pos = contact_pos_a
2657-
collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b] = i_va
2661+
collider_state.contact_cache.i_va_ws[0, i_pair, i_b] = i_va
26582662
elif is_col_b and (
26592663
not is_col_a or penetration_b > max(penetration_a, (not prefer_sdf) * penetration)
26602664
):
26612665
normal = normal_b
26622666
penetration = penetration_b
26632667
contact_pos = contact_pos_b
2664-
collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b] = i_vb
2668+
collider_state.contact_cache.i_va_ws[1, i_pair, i_b] = i_vb
26652669
elif not is_col_a and not is_col_b:
26662670
is_col = False
26672671

@@ -2701,12 +2705,12 @@ def func_convex_convex_contact(
27012705
if ti.static(
27022706
collider_static_config.ccd_algorithm in (CCD_ALGORITHM_CODE.MPR, CCD_ALGORITHM_CODE.GJK)
27032707
):
2704-
collider_state.contact_cache.normal[i_ga, i_gb, i_b] = normal
2708+
collider_state.contact_cache.normal[i_pair, i_b] = normal
27052709
else:
27062710
# Clear collision normal cache if not in contact
2707-
collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b] = -1
2708-
collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b] = -1
2709-
collider_state.contact_cache.normal[i_ga, i_gb, i_b] = ti.Vector.zero(gs.ti_float, 3)
2711+
collider_state.contact_cache.i_va_ws[0, i_pair, i_b] = -1
2712+
collider_state.contact_cache.i_va_ws[1, i_pair, i_b] = -1
2713+
collider_state.contact_cache.normal[i_pair, i_b] = ti.Vector.zero(gs.ti_float, 3)
27102714

27112715
elif multi_contact and is_col_0 > 0 and is_col > 0:
27122716
if ti.static(collider_static_config.ccd_algorithm in (CCD_ALGORITHM_CODE.MPR, CCD_ALGORITHM_CODE.GJK)):

genesis/utils/array_class.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -406,11 +406,11 @@ class StructContactCache(metaclass=BASE_METACLASS):
406406
normal: V_ANNOTATION
407407

408408

409-
def get_contact_cache(solver):
409+
def get_contact_cache(solver, n_possible_pairs):
410410
_B = solver._B
411411
return StructContactCache(
412-
i_va_ws=V(dtype=gs.ti_int, shape=(solver.n_geoms_, solver.n_geoms_, _B)),
413-
normal=V_VEC(3, dtype=gs.ti_float, shape=(solver.n_geoms_, solver.n_geoms_, _B)),
412+
i_va_ws=V(dtype=gs.ti_int, shape=(2, n_possible_pairs, _B)),
413+
normal=V_VEC(3, dtype=gs.ti_float, shape=(n_possible_pairs, _B)),
414414
)
415415

416416

@@ -548,7 +548,7 @@ def get_collider_state(
548548
n_contacts=V(dtype=gs.ti_int, shape=(_B,)),
549549
n_contacts_hibernated=V(dtype=gs.ti_int, shape=(_B,)),
550550
first_time=V(dtype=gs.ti_bool, shape=(_B,)),
551-
contact_cache=get_contact_cache(solver),
551+
contact_cache=get_contact_cache(solver, n_possible_pairs),
552552
broad_collision_pairs=V_VEC(2, dtype=gs.ti_int, shape=(max(max_collision_pairs_broad, 1), _B)),
553553
contact_data=get_contact_data(solver, max_contact_pairs, requires_grad),
554554
diff_contact_input=get_diff_contact_input(solver, max(max_contact_pairs, 1), is_active=True),
@@ -560,7 +560,7 @@ class StructColliderInfo(metaclass=BASE_METACLASS):
560560
vert_neighbors: V_ANNOTATION
561561
vert_neighbor_start: V_ANNOTATION
562562
vert_n_neighbors: V_ANNOTATION
563-
collision_pair_validity: V_ANNOTATION
563+
collision_pair_idx: V_ANNOTATION
564564
max_possible_pairs: V_ANNOTATION
565565
max_collision_pairs: V_ANNOTATION
566566
max_contact_pairs: V_ANNOTATION
@@ -591,7 +591,7 @@ def get_collider_info(solver, n_vert_neighbors, collider_static_config, **kwargs
591591
vert_neighbors=V(dtype=gs.ti_int, shape=(max(n_vert_neighbors, 1),)),
592592
vert_neighbor_start=V(dtype=gs.ti_int, shape=(solver.n_verts_,)),
593593
vert_n_neighbors=V(dtype=gs.ti_int, shape=(solver.n_verts_,)),
594-
collision_pair_validity=V(dtype=gs.ti_int, shape=(solver.n_geoms_, solver.n_geoms_)),
594+
collision_pair_idx=V(dtype=gs.ti_int, shape=(solver.n_geoms_, solver.n_geoms_)),
595595
max_possible_pairs=V(dtype=gs.ti_int, shape=()),
596596
max_collision_pairs=V(dtype=gs.ti_int, shape=()),
597597
max_contact_pairs=V(dtype=gs.ti_int, shape=()),

0 commit comments

Comments
 (0)