Skip to content

Commit 2c95b26

Browse files
authored
[FEATURE] Leverage GsTaichi zero-copy in data accessors. (#2011)
* Fix 'redirect_libc_stderr' raising exception if stderr is not a real OS-like file descriptor. * Enable GsTaichi dynamic array mode on Mac OS. * Leverage GsTaichi zero-copy in data accessors.
1 parent ab43ab7 commit 2c95b26

File tree

8 files changed

+249
-131
lines changed

8 files changed

+249
-131
lines changed

genesis/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
backend: gs_backend | None = None
4242
use_ndarray: bool | None = None
4343
use_fastcache: bool | None = None
44+
use_zerocopy: bool | None = None
4445
EPS: float | None = None
4546

4647

@@ -117,8 +118,8 @@ def init(
117118
backend = gs_backend.cpu
118119

119120
# Configure GsTaichi fast cache and array type
120-
global use_ndarray, use_fastcache
121-
is_ndarray_disabled = (os.environ.get("GS_ENABLE_NDARRAY") or ("0" if sys.platform == "darwin" else "1")) == "0"
121+
global use_ndarray, use_fastcache, use_zerocopy
122+
is_ndarray_disabled = (os.environ.get("GS_ENABLE_NDARRAY") or ("0" if backend == gs_backend.metal else "1")) == "0"
122123
if use_ndarray is None:
123124
_use_ndarray = not (is_ndarray_disabled or performance_mode)
124125
else:
@@ -136,6 +137,20 @@ def init(
136137
raise_exception("Genesis previous initialized. GsTaichi fast cache mode cannot be disabled anymore.")
137138
use_ndarray, use_fastcache = _use_ndarray, _use_fastcache
138139

140+
# Unlike dynamic vs static array mode, and fastcache, zero-copy can be toggle on/off between init without issue.
141+
# FIXME: ti.Field does not support zero-copy on Metal for now because of a bug in Torch itself.
142+
# See: https://github.com/pytorch/pytorch/pull/168193
143+
# FIXME: Zero-copy is currently broken for ti.Field for some reason...
144+
_use_zerocopy = int(os.environ["GS_ENABLE_ZEROCOPY"]) if "GS_ENABLE_ZEROCOPY" in os.environ else None
145+
if backend in (gs_backend.cpu, gs_backend.cuda):
146+
if _use_zerocopy is None:
147+
_use_zerocopy = True
148+
else:
149+
if _use_zerocopy:
150+
raise_exception(f"Zero-copy only support by GsTaichi dynamic array mode on CPU and CUDA backend.")
151+
_use_zerocopy = False
152+
use_zerocopy = _use_zerocopy and _use_ndarray # (_use_ndarray or backend != gs_backend.metal)
153+
139154
# Define the right dtypes in accordance with selected backend and precision
140155
global ti_float, np_float, tc_float
141156
if precision == "32":

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ def _build(self):
561561

562562
self._n_qs = self.n_qs
563563
self._n_dofs = self.n_dofs
564+
self._n_geoms = self.n_geoms
564565
self._is_built = True
565566

566567
verts_start = 0
@@ -576,6 +577,8 @@ def _build(self):
576577
self._free_verts_idx_local = torch.cat(free_verts_idx_local)
577578
if fixed_verts_idx_local:
578579
self._fixed_verts_idx_local = torch.cat(fixed_verts_idx_local)
580+
self._n_free_verts = len(self._free_verts_idx_local)
581+
self._n_fixed_verts = len(self._fixed_verts_idx_local)
579582

580583
self._geoms = self.geoms
581584
self._vgeoms = self.vgeoms
@@ -1347,19 +1350,13 @@ def inverse_kinematics_multilink(
13471350
)
13481351

13491352
qpos = ti_to_torch(self._IK_qpos_best, transpose=True)
1350-
if self._solver.n_envs == 0:
1351-
qpos = qpos[0].clone()
1352-
else:
1353-
qpos = qpos[envs_idx]
1353+
qpos = qpos[0] if self._solver.n_envs == 0 else qpos[envs_idx]
13541354

13551355
if return_error:
13561356
error_pose = ti_to_torch(self._IK_err_pose_best, transpose=True).reshape((-1, self._IK_n_tgts, 6))[
13571357
:, :n_links
13581358
]
1359-
if self._solver.n_envs == 0:
1360-
error_pose = error_pose[0].clone()
1361-
else:
1362-
error_pose = error_pose[envs_idx]
1359+
error_pose = error_pose[0] if self._solver.n_envs == 0 else error_pose[envs_idx]
13631360
return qpos, error_pose
13641361
return qpos
13651362

@@ -2029,23 +2026,36 @@ def get_verts(self):
20292026
verts : torch.Tensor, shape (n_envs, n_verts, 3)
20302027
The vertices of the entity.
20312028
"""
2032-
self._solver.update_verts_for_geoms(range(self.geom_start, self.geom_end))
2029+
self._solver.update_verts_for_geoms(slice(self.geom_start, self.geom_end))
20332030

2034-
tensor = torch.empty((self._solver._B, self.n_verts, 3), dtype=gs.tc_float, device=gs.device)
2035-
has_fixed_verts, has_free_vertices = len(self._fixed_verts_idx_local) > 0, len(self._free_verts_idx_local) > 0
2036-
if has_fixed_verts:
2037-
_kernel_get_fixed_verts(
2038-
tensor, self._fixed_verts_idx_local, self._fixed_verts_state_start, self._solver.fixed_verts_state
2039-
)
2040-
if has_free_vertices:
2041-
# FIXME: Get around some bug in gstaichi when using gstaichi with metal backend
2042-
must_copy = gs.backend == gs.metal and has_fixed_verts
2043-
tensor_free = torch.zeros_like(tensor) if must_copy else tensor
2044-
_kernel_get_free_verts(
2045-
tensor_free, self._free_verts_idx_local, self._free_verts_state_start, self._solver.free_verts_state
2046-
)
2047-
if must_copy:
2048-
tensor += tensor_free
2031+
n_fixed_verts, n_free_vertices = self._n_fixed_verts, self._n_free_verts
2032+
tensor = torch.empty((self._solver._B, n_fixed_verts + n_free_vertices, 3), dtype=gs.tc_float, device=gs.device)
2033+
2034+
if n_fixed_verts > 0:
2035+
if gs.use_zerocopy:
2036+
fixed_verts_state = ti_to_torch(self._solver.fixed_verts_state.pos)
2037+
tensor[:, self._fixed_verts_idx_local] = fixed_verts_state[
2038+
self._fixed_verts_state_start : self._fixed_verts_state_start + n_fixed_verts
2039+
]
2040+
else:
2041+
_kernel_get_fixed_verts(
2042+
tensor, self._fixed_verts_idx_local, self._fixed_verts_state_start, self._solver.fixed_verts_state
2043+
)
2044+
if n_free_vertices > 0:
2045+
if gs.use_zerocopy:
2046+
free_verts_state = ti_to_torch(self._solver.free_verts_state.pos, transpose=True)
2047+
tensor[:, self._free_verts_idx_local] = free_verts_state[
2048+
:, self._free_verts_state_start : self._free_verts_state_start + n_free_vertices
2049+
]
2050+
else:
2051+
# FIXME: Get around some bug in gstaichi when using gstaichi with metal backend
2052+
must_copy = gs.backend == gs.metal and n_fixed_verts > 0
2053+
tensor_free = torch.zeros_like(tensor) if must_copy else tensor
2054+
_kernel_get_free_verts(
2055+
tensor_free, self._free_verts_idx_local, self._free_verts_state_start, self._solver.free_verts_state
2056+
)
2057+
if must_copy:
2058+
tensor += tensor_free
20492059

20502060
if self._solver.n_envs == 0:
20512061
tensor = tensor[0]
@@ -2854,6 +2864,8 @@ def n_dofs(self):
28542864
@property
28552865
def n_geoms(self):
28562866
"""The number of `RigidGeom` in the entity."""
2867+
if self._is_built:
2868+
return self._n_geoms
28572869
return sum(link.n_geoms for link in self._links)
28582870

28592871
@property

genesis/engine/entities/rigid_entity/rigid_link.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def get_verts(self):
305305
"""
306306
Get the vertices of the link's collision body (concatenation of all `link.geoms`) in the world frame.
307307
"""
308-
self._solver.update_verts_for_geoms(range(self.geom_start, self.geom_end))
308+
self._solver.update_verts_for_geoms(slice(self.geom_start, self.geom_end))
309309

310310
if self.is_fixed and not self._entity._batch_fixed_verts:
311311
tensor = torch.empty((self.n_verts, 3), dtype=gs.tc_float, device=gs.device)

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,13 @@ def substep(self):
894894
)
895895

896896
def check_errno(self):
897-
match kernel_get_errno(self._errno):
897+
# Note that errno must be evaluated BEFORE match because otherwise it will be evaluated for each case...
898+
# See official documentation: https://docs.python.org/3.10/reference/compound_stmts.html#overview
899+
if gs.use_zerocopy:
900+
errno = int(ti_to_torch(self._errno, copy=None, non_blocking=True))
901+
else:
902+
errno = kernel_get_errno(self._errno)
903+
match errno:
898904
case 1:
899905
max_collision_pairs_broad = self.collider._collider_info.max_collision_pairs_broad[None]
900906
gs.raise_exception(
@@ -1362,8 +1368,10 @@ def _sanitize_1D_io_variables(
13621368
_inputs_idx = torch.as_tensor(inputs_idx, dtype=gs.tc_int, device=gs.device).contiguous()
13631369
if _inputs_idx is not inputs_idx:
13641370
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
1365-
_inputs_idx = torch.atleast_1d(_inputs_idx)
1366-
if _inputs_idx.ndim != 1:
1371+
_inputs_ndim = _inputs_idx.ndim
1372+
if _inputs_ndim == 0:
1373+
_inputs_idx = _inputs_idx[None]
1374+
elif _inputs_ndim > 1:
13671375
gs.raise_exception(f"Expecting 1D tensor for `{idx_name}`.")
13681376
if not ((0 <= _inputs_idx).all() or (_inputs_idx < input_size).all()):
13691377
gs.raise_exception(f"`{idx_name}` is out-of-range.")
@@ -1372,19 +1380,23 @@ def _sanitize_1D_io_variables(
13721380
_tensor = torch.as_tensor(tensor, dtype=gs.tc_float, device=gs.device).contiguous()
13731381
if _tensor is not tensor:
13741382
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
1375-
tensor = _tensor.unsqueeze(0) if batched and self.n_envs and _tensor.ndim == 1 else _tensor
1376-
1383+
tensor_ndim = _tensor.ndim
1384+
if batched and self.n_envs and tensor_ndim == 1:
1385+
tensor = _tensor.unsqueeze(0)
1386+
tensor_ndim += 1
1387+
else:
1388+
tensor = _tensor
13771389
if tensor.shape[-1] != len(inputs_idx):
13781390
gs.raise_exception(f"Last dimension of the input tensor does not match length of `{idx_name}`.")
13791391

13801392
if batched:
13811393
if self.n_envs == 0:
1382-
if tensor.ndim != 1:
1394+
if tensor_ndim != 1:
13831395
gs.raise_exception(
13841396
f"Invalid input shape: {tensor.shape}. Expecting a 1D tensor for non-parallelized scene."
13851397
)
13861398
else:
1387-
if tensor.ndim == 2:
1399+
if tensor_ndim == 2:
13881400
if tensor.shape[0] != len(envs_idx):
13891401
gs.raise_exception(
13901402
f"Invalid input shape: {tensor.shape}. First dimension of the input tensor does not match "
@@ -1395,7 +1407,7 @@ def _sanitize_1D_io_variables(
13951407
f"Invalid input shape: {tensor.shape}. Expecting a 2D tensor for scene with parallelized envs."
13961408
)
13971409
else:
1398-
if tensor.ndim != 1:
1410+
if tensor_ndim != 1:
13991411
gs.raise_exception("Expecting 1D output tensor.")
14001412
return tensor, _inputs_idx, envs_idx
14011413

@@ -2285,7 +2297,12 @@ def get_equality_constraints(self, as_tensor: bool = True, to_torch: bool = True
22852297
return self.constraint_solver.get_equality_constraints(as_tensor, to_torch)
22862298

22872299
def clear_external_force(self):
2288-
kernel_clear_external_force(self.links_state, self._rigid_global_info, self._static_rigid_sim_config)
2300+
if gs.use_zerocopy:
2301+
for tensor in (self.links_state.cfrc_applied_ang, self.links_state.cfrc_applied_vel):
2302+
out = ti_to_python(tensor, copy=False, non_blocking=True)
2303+
out.zero_()
2304+
else:
2305+
kernel_clear_external_force(self.links_state, self._rigid_global_info, self._static_rigid_sim_config)
22892306

22902307
def update_vgeoms(self):
22912308
kernel_update_vgeoms(self.vgeoms_info, self.vgeoms_state, self.links_state, self._static_rigid_sim_config)
@@ -2320,6 +2337,11 @@ def set_drone_rpm(self, n_propellers, propellers_link_idxs, propellers_rpm, prop
23202337
)
23212338

23222339
def update_verts_for_geoms(self, geoms_idx):
2340+
if gs.use_zerocopy:
2341+
verts_updated = ti_to_torch(self.geoms_state.verts_updated, transpose=False)
2342+
if verts_updated[geoms_idx].all():
2343+
return
2344+
23232345
_, geoms_idx, _ = self._sanitize_1D_io_variables(
23242346
None, geoms_idx, self.n_geoms, None, idx_name="geoms_idx", skip_allocation=True, unsafe=False
23252347
)

0 commit comments

Comments
 (0)