Skip to content

Commit 855c494

Browse files
authored
[MISC] Leverage zero-copy mode to speed up Rigid Body accessors. (#2021)
* Disable input out-of-range check for efficiency. * Add zero-copy setters for control_* methods. * Add zero-copy Rigid link velocity getter. * Add zero-copy contacts info getter.
1 parent cb8dbd3 commit 855c494

File tree

8 files changed

+249
-154
lines changed

8 files changed

+249
-154
lines changed

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(
8787

8888
self._free_verts_idx_local = torch.tensor([], dtype=gs.tc_int, device=gs.device)
8989
self._fixed_verts_idx_local = torch.tensor([], dtype=gs.tc_int, device=gs.device)
90-
self._base_links_idx = torch.tensor([self.base_link_idx], dtype=gs.tc_int, device=gs.device)
90+
self._base_links_idx_ = torch.tensor([self.base_link_idx], dtype=gs.tc_int, device=gs.device)
9191

9292
self._batch_fixed_verts = morph.batch_fixed_verts
9393

@@ -1702,7 +1702,7 @@ def get_pos(self, envs_idx=None, *, unsafe=False):
17021702
pos : torch.Tensor, shape (3,) or (n_envs, 3)
17031703
The position of the entity's base link.
17041704
"""
1705-
return self._solver.get_links_pos(self._base_links_idx, envs_idx, unsafe=unsafe).squeeze(-2)
1705+
return self._solver.get_links_pos(self.base_link_idx, envs_idx, unsafe=unsafe).squeeze(-2)
17061706

17071707
@gs.assert_built
17081708
def get_quat(self, envs_idx=None, *, unsafe=False):
@@ -1719,7 +1719,7 @@ def get_quat(self, envs_idx=None, *, unsafe=False):
17191719
quat : torch.Tensor, shape (4,) or (n_envs, 4)
17201720
The quaternion of the entity's base link.
17211721
"""
1722-
return self._solver.get_links_quat(self._base_links_idx, envs_idx, unsafe=unsafe).squeeze(-2)
1722+
return self._solver.get_links_quat(self.base_link_idx, envs_idx, unsafe=unsafe).squeeze(-2)
17231723

17241724
@gs.assert_built
17251725
def get_vel(self, envs_idx=None, *, unsafe=False):
@@ -1736,7 +1736,7 @@ def get_vel(self, envs_idx=None, *, unsafe=False):
17361736
vel : torch.Tensor, shape (3,) or (n_envs, 3)
17371737
The linear velocity of the entity's base link.
17381738
"""
1739-
return self._solver.get_links_vel(self._base_links_idx, envs_idx, unsafe=unsafe).squeeze(-2)
1739+
return self._solver.get_links_vel(self.base_link_idx, envs_idx, unsafe=unsafe).squeeze(-2)
17401740

17411741
@gs.assert_built
17421742
def get_ang(self, envs_idx=None, *, unsafe=False):
@@ -1753,7 +1753,7 @@ def get_ang(self, envs_idx=None, *, unsafe=False):
17531753
ang : torch.Tensor, shape (3,) or (n_envs, 3)
17541754
The angular velocity of the entity's base link.
17551755
"""
1756-
return self._solver.get_links_ang(self._base_links_idx, envs_idx, unsafe=unsafe).squeeze(-2)
1756+
return self._solver.get_links_ang(self.base_link_idx, envs_idx, unsafe=unsafe).squeeze(-2)
17571757

17581758
@gs.assert_built
17591759
def get_links_pos(
@@ -1973,7 +1973,7 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
19731973
pos = _pos
19741974
self._solver.set_base_links_pos(
19751975
pos.unsqueeze(-2),
1976-
self._base_links_idx,
1976+
self._base_links_idx_,
19771977
envs_idx,
19781978
relative=relative,
19791979
unsafe=unsafe,
@@ -2007,7 +2007,7 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
20072007
quat = _quat
20082008
self._solver.set_base_links_quat(
20092009
quat.unsqueeze(-2),
2010-
self._base_links_idx,
2010+
self._base_links_idx_,
20112011
envs_idx,
20122012
relative=relative,
20132013
unsafe=unsafe,

genesis/engine/solvers/rigid/collider_decomp.py

Lines changed: 101 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99

1010
import genesis as gs
1111
import genesis.utils.geom as gu
12-
from genesis.styles import colors, formats
1312
import genesis.utils.array_class as array_class
1413
import genesis.engine.solvers.rigid.gjk_decomp as gjk
1514
import genesis.engine.solvers.rigid.diff_gjk_decomp as diff_gjk
1615
import genesis.engine.solvers.rigid.mpr_decomp as mpr
1716
import genesis.utils.sdf_decomp as sdf
1817
import genesis.engine.solvers.rigid.support_field_decomp as support_field
1918
import genesis.engine.solvers.rigid.rigid_solver_decomp as rigid_solver
19+
from genesis.utils.misc import tensor_to_array, ti_to_torch, ti_to_numpy
2020

2121
from .mpr_decomp import MPR
2222
from .gjk_decomp import GJK
@@ -62,6 +62,22 @@ def __init__(self, rigid_solver: "RigidSolver"):
6262
self._init_static_config()
6363
self._init_collision_fields()
6464

65+
if gs.use_zerocopy:
66+
self._contacts_info: dict[str, torch.Tensor] = {}
67+
for key, name in (
68+
("link_a", "link_a"),
69+
("link_b", "link_b"),
70+
("geom_a", "geom_a"),
71+
("geom_b", "geom_b"),
72+
("penetration", "penetration"),
73+
("position", "pos"),
74+
("normal", "normal"),
75+
("force", "force"),
76+
):
77+
self._contacts_info[key] = ti_to_torch(
78+
getattr(self._collider_state.contact_data, name), transpose=True, copy=False
79+
)
80+
6581
# Support field used for mpr and gjk. Rather than having separate support fields for each algorithm, keep only
6682
# one copy here to save memory and maintain cleaner code.
6783
self._support_field = SupportField(rigid_solver)
@@ -131,8 +147,8 @@ def _init_collision_fields(self) -> None:
131147
self._collider_static_config,
132148
)
133149

134-
# [contacts_info_cache] is not used in Taichi kernels, so keep it outside of the collider state / info.
135-
self._contacts_info_cache = {}
150+
# 'contacts_info_cache' is not used in Taichi kernels, so keep it outside of the collider state / info
151+
self._contacts_info_cache: dict[tuple[bool, bool], dict[str, torch.Tensor | tuple[torch.Tensor]]] = {}
136152

137153
self.reset()
138154

@@ -288,7 +304,7 @@ def reset(self, envs_idx: npt.NDArray[np.int32] | None = None) -> None:
288304
self._solver._static_rigid_sim_config,
289305
self._collider_state,
290306
)
291-
self._contacts_info_cache = {}
307+
self._contacts_info_cache.clear()
292308

293309
def clear(self, envs_idx=None):
294310
if envs_idx is None:
@@ -302,7 +318,7 @@ def clear(self, envs_idx=None):
302318
)
303319

304320
def detection(self) -> None:
305-
self._contacts_info_cache = {}
321+
self._contacts_info_cache.clear()
306322
rigid_solver.kernel_update_geom_aabbs(
307323
self._solver.geoms_state,
308324
self._solver.geoms_init_AABB,
@@ -391,18 +407,46 @@ def detection(self) -> None:
391407

392408
def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch_dim: bool = False):
393409
# Early return if already pre-computed
394-
contacts_info = self._contacts_info_cache.get((as_tensor, to_torch))
395-
if contacts_info is not None:
410+
contacts_info = self._contacts_info_cache.setdefault((as_tensor, to_torch), {})
411+
if contacts_info:
412+
return contacts_info.copy()
413+
414+
n_envs = self._solver.n_envs
415+
if gs.use_zerocopy:
416+
if as_tensor or n_envs == 0:
417+
n_contacts_max = ti_to_torch(self._collider_state.n_contacts_max, copy=False).item()
418+
else:
419+
n_contacts = ti_to_torch(self._collider_state.n_contacts, copy=False)
420+
421+
for key, data in self._contacts_info.items():
422+
if n_envs == 0:
423+
data = data[0, :n_contacts_max]
424+
if not to_torch:
425+
data = tensor_to_array(data)
426+
else:
427+
if as_tensor:
428+
data = data[:, :n_contacts_max]
429+
if not to_torch:
430+
data = tensor_to_array(data)
431+
else:
432+
if not to_torch:
433+
data = tensor_to_array(data)
434+
if keep_batch_dim:
435+
data = tuple([data[i : i + 1, :j] for i, j in enumerate(n_contacts.tolist())])
436+
else:
437+
data = tuple([data[i, :j] for i, j in enumerate(n_contacts.tolist())])
438+
contacts_info[key] = data
439+
396440
return contacts_info.copy()
397441

398442
# Find out how much dynamic memory must be allocated
399-
n_contacts = tuple(self._collider_state.n_contacts.to_numpy())
400-
n_envs = len(n_contacts)
401-
n_contacts_max = max(n_contacts)
443+
n_contacts = ti_to_numpy(self._collider_state.n_contacts)
444+
n_contacts_max = n_contacts.max().item()
402445
if as_tensor:
403-
out_size = n_contacts_max * n_envs
446+
out_size = n_contacts_max * max(n_envs, 1)
404447
else:
405448
*n_contacts_starts, out_size = np.cumsum(n_contacts)
449+
n_contacts = n_contacts.tolist()
406450

407451
# Allocate output buffer
408452
if to_torch:
@@ -415,29 +459,23 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
415459
# Copy contact data
416460
if n_contacts_max > 0:
417461
collider_kernel_get_contacts(
418-
as_tensor,
419-
iout,
420-
fout,
421-
self._solver._rigid_global_info,
422-
self._solver._static_rigid_sim_config,
423-
self._collider_state,
424-
self._collider_info,
462+
as_tensor, iout, fout, self._solver._static_rigid_sim_config, self._collider_state
425463
)
426464

427465
# Build structured view (no copy)
428466
if as_tensor:
429-
if self._solver.n_envs > 0:
467+
if n_envs > 0:
430468
iout = iout.reshape((n_envs, n_contacts_max, 4))
431469
fout = fout.reshape((n_envs, n_contacts_max, 10))
432-
if keep_batch_dim and self._solver.n_envs == 0:
470+
if keep_batch_dim and n_envs == 0:
433471
iout = iout.reshape((1, n_contacts_max, 4))
434472
fout = fout.reshape((1, n_contacts_max, 10))
435473
iout_chunks = (iout[..., 0], iout[..., 1], iout[..., 2], iout[..., 3])
436474
fout_chunks = (fout[..., 0], fout[..., 1:4], fout[..., 4:7], fout[..., 7:])
437475
values = (*iout_chunks, *fout_chunks)
438476
else:
439477
# Split smallest dimension first, then largest dimension
440-
if self._solver.n_envs == 0:
478+
if n_envs == 0:
441479
iout_chunks = (iout[..., 0], iout[..., 1], iout[..., 2], iout[..., 3])
442480
fout_chunks = (fout[..., 0], fout[..., 1:4], fout[..., 4:7], fout[..., 7:])
443481
values = (*iout_chunks, *fout_chunks)
@@ -454,7 +492,7 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
454492
else:
455493
iout_chunks = (iout[..., 0], iout[..., 1], iout[..., 2], iout[..., 3])
456494
fout_chunks = (fout[..., 0], fout[..., 1:4], fout[..., 4:7], fout[..., 7:])
457-
if self._solver.n_envs == 1:
495+
if n_envs == 1:
458496
values = [(value,) for value in (*iout_chunks, *fout_chunks)]
459497
else:
460498
if to_torch:
@@ -465,13 +503,11 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
465503
fout_chunks = (np.split(out, n_contacts_starts) for out in fout_chunks)
466504
values = (*iout_chunks, *fout_chunks)
467505

468-
contacts_info = dict(
506+
# Store contact information in cache
507+
contacts_info.update(
469508
zip(("link_a", "link_b", "geom_a", "geom_b", "penetration", "position", "normal", "force"), values)
470509
)
471510

472-
# Cache contact information before returning
473-
self._contacts_info_cache[(as_tensor, to_torch)] = contacts_info
474-
475511
return contacts_info.copy()
476512

477513
def backward(self, dL_dposition, dL_dnormal, dL_dpenetration):
@@ -514,16 +550,16 @@ def collider_kernel_reset(
514550
static_rigid_sim_config: ti.template(),
515551
collider_state: array_class.ColliderState,
516552
):
553+
n_geoms = collider_state.active_buffer.shape[0]
554+
517555
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
518556
for i_b_ in range(envs_idx.shape[0]):
519557
i_b = envs_idx[i_b_]
520558
collider_state.first_time[i_b] = 1
521-
n_geoms = collider_state.active_buffer.shape[0]
522-
for i_ga in range(n_geoms):
523-
for i_gb in range(n_geoms):
524-
collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b] = -1
525-
collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b] = -1
526-
collider_state.contact_cache.normal[i_ga, i_gb, i_b] = ti.Vector.zero(gs.ti_float, 3)
559+
for i_ga, i_gb in ti.ndrange(n_geoms, n_geoms):
560+
collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b] = -1
561+
collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b] = -1
562+
collider_state.contact_cache.normal[i_ga, i_gb, i_b] = ti.Vector.zero(gs.ti_float, 3)
527563

528564

529565
# only used with hibernation ??
@@ -574,31 +610,34 @@ def kernel_collider_clear(
574610

575611
collider_state.n_contacts_hibernated[i_b] = i_c_hibernated + 1
576612

613+
for i_c in range(collider_state.n_contacts[i_b]):
614+
collider_state.contact_data.link_a[i_c, i_b] = -1
615+
collider_state.contact_data.link_b[i_c, i_b] = -1
616+
collider_state.contact_data.geom_a[i_c, i_b] = -1
617+
collider_state.contact_data.geom_b[i_c, i_b] = -1
618+
collider_state.contact_data.penetration[i_c, i_b] = 0.0
619+
collider_state.contact_data.pos[i_c, i_b] = ti.Vector.zero(gs.ti_float, 3)
620+
collider_state.contact_data.normal[i_c, i_b] = ti.Vector.zero(gs.ti_float, 3)
621+
collider_state.contact_data.force[i_c, i_b] = ti.Vector.zero(gs.ti_float, 3)
622+
623+
if ti.static(static_rigid_sim_config.use_hibernation):
577624
collider_state.n_contacts[i_b] = collider_state.n_contacts_hibernated[i_b]
578625
else:
579626
collider_state.n_contacts[i_b] = 0
580627

628+
collider_state.n_contacts_max[None] = 0
629+
581630

582631
@ti.kernel(fastcache=gs.use_fastcache)
583632
def collider_kernel_get_contacts(
584633
is_padded: ti.template(),
585634
iout: ti.types.ndarray(),
586635
fout: ti.types.ndarray(),
587-
rigid_global_info: array_class.RigidGlobalInfo,
588636
static_rigid_sim_config: ti.template(),
589637
collider_state: array_class.ColliderState,
590-
collider_info: array_class.ColliderInfo,
591638
):
592639
_B = collider_state.active_buffer.shape[1]
593-
n_contacts_max = gs.ti_int(0)
594-
595-
# this is a reduction operation (global max), we have to serialize it
596-
# TODO: a good unittest and a better implementation from gstaichi for this kind of reduction
597-
ti.loop_config(serialize=True)
598-
for i_b in range(_B):
599-
n_contacts = collider_state.n_contacts[i_b]
600-
if n_contacts > n_contacts_max:
601-
n_contacts_max = n_contacts
640+
n_contacts_max = collider_state.n_contacts_max[None]
602641

603642
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
604643
for i_b in range(_B):
@@ -1187,9 +1226,10 @@ def func_collision_clear(
11871226
static_rigid_sim_config: ti.template(),
11881227
):
11891228
_B = collider_state.n_contacts.shape[0]
1190-
if ti.static(static_rigid_sim_config.use_hibernation):
1191-
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
1192-
for i_b in range(_B):
1229+
1230+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
1231+
for i_b in range(_B):
1232+
if ti.static(static_rigid_sim_config.use_hibernation):
11931233
collider_state.n_contacts_hibernated[i_b] = 0
11941234

11951235
# Advect hibernated contacts
@@ -1210,12 +1250,23 @@ def func_collision_clear(
12101250
collider_state.contact_data[i_c_hibernated, i_b] = collider_state.contact_data[i_c, i_b]
12111251
collider_state.n_contacts_hibernated[i_b] = i_c_hibernated + 1
12121252

1253+
for i_c in range(collider_state.n_contacts[i_b]):
1254+
collider_state.contact_data.link_a[i_c, i_b] = -1
1255+
collider_state.contact_data.link_b[i_c, i_b] = -1
1256+
collider_state.contact_data.geom_a[i_c, i_b] = -1
1257+
collider_state.contact_data.geom_b[i_c, i_b] = -1
1258+
collider_state.contact_data.penetration[i_c, i_b] = 0.0
1259+
collider_state.contact_data.pos[i_c, i_b] = ti.Vector.zero(gs.ti_float, 3)
1260+
collider_state.contact_data.normal[i_c, i_b] = ti.Vector.zero(gs.ti_float, 3)
1261+
collider_state.contact_data.force[i_c, i_b] = ti.Vector.zero(gs.ti_float, 3)
1262+
1263+
if ti.static(static_rigid_sim_config.use_hibernation):
12131264
collider_state.n_contacts[i_b] = collider_state.n_contacts_hibernated[i_b]
1214-
else:
1215-
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
1216-
for i_b in range(_B):
1265+
else:
12171266
collider_state.n_contacts[i_b] = 0
12181267

1268+
collider_state.n_contacts_max[None] = 0
1269+
12191270

12201271
@ti.kernel(fastcache=gs.use_fastcache)
12211272
def func_broad_phase(
@@ -2087,6 +2138,7 @@ def func_add_contact(
20872138
collider_state.contact_data.link_b[i_c, i_b] = geoms_info.link_idx[i_gb]
20882139

20892140
collider_state.n_contacts[i_b] = i_c + 1
2141+
ti.atomic_max(collider_state.n_contacts_max[None], i_c + 1)
20902142
else:
20912143
errno[None] = 2
20922144

0 commit comments

Comments
 (0)