Skip to content

Commit 13cd721

Browse files
authored
[MISC] Speedup zero-copy mode. (#2019)
* Speedup 'extract_slice'. * Remove non-blocking mode.
1 parent 2c95b26 commit 13cd721

File tree

2 files changed

+12
-21
lines changed

2 files changed

+12
-21
lines changed

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,7 @@ def check_errno(self):
897897
# Note that errno must be evaluated BEFORE match because otherwise it will be evaluated for each case...
898898
# See official documentation: https://docs.python.org/3.10/reference/compound_stmts.html#overview
899899
if gs.use_zerocopy:
900-
errno = int(ti_to_torch(self._errno, copy=None, non_blocking=True))
900+
errno = int(ti_to_torch(self._errno, copy=None))
901901
else:
902902
errno = kernel_get_errno(self._errno)
903903
match errno:
@@ -2299,7 +2299,7 @@ def get_equality_constraints(self, as_tensor: bool = True, to_torch: bool = True
22992299
def clear_external_force(self):
23002300
if gs.use_zerocopy:
23012301
for tensor in (self.links_state.cfrc_applied_ang, self.links_state.cfrc_applied_vel):
2302-
out = ti_to_python(tensor, copy=False, non_blocking=True)
2302+
out = ti_to_python(tensor, copy=False)
23032303
out.zero_()
23042304
else:
23052305
kernel_clear_external_force(self.links_state, self._rigid_global_info, self._static_rigid_sim_config)

genesis/utils/misc.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,6 @@ def ti_to_python(
573573
transpose: bool = False,
574574
copy: bool | None = True,
575575
to_torch: bool = True,
576-
non_blocking: bool = False,
577576
) -> torch.Tensor | np.ndarray:
578577
"""Converts a GsTaichi field / ndarray instance to a PyTorch tensor / Numpy array.
579578
@@ -582,8 +581,6 @@ def ti_to_python(
582581
transpose (bool, optional): Whether to move the last batch dimension in front. Defaults to False.
583582
copy (bool, optional): Whether to enforce returning a copy no matter what. None to avoid copy if possible
584583
without raising an exception if not.
585-
non_blocking (bool): Whether to skip GPU synchronization. It will be faster, but there will be no guarantee
586-
that the return buffer is up-to-date. Default to False.
587584
to_torch (bool): Whether to convert to Torch tensor or Numpy array. Defaults to True.
588585
"""
589586
# Check if copy mode is supported while setting default mode if not specified.
@@ -621,8 +618,6 @@ def ti_to_python(
621618
value._np = value._tc.numpy()
622619
if not to_torch:
623620
out = value._np
624-
if not non_blocking:
625-
ti.sync()
626621
if copy:
627622
if to_torch:
628623
out = out.clone()
@@ -669,11 +664,11 @@ def ti_to_python(
669664
# Transpose if necessary and requested.
670665
# Note that it is worth transposing here before slicing, as it preserves row-major memory alignment in case of
671666
# advanced masking, which would spare computation later on if expected from the user.
672-
if transpose and len(ti_data_meta.shape) > 1:
667+
if transpose and (batch_ndim := len(ti_data_meta.shape)) > 1:
673668
if to_torch:
674-
out = out.movedim(out.ndim - ti_data_meta.ndim - 1, 0)
669+
out = out.movedim(batch_ndim - 1, 0)
675670
else:
676-
out = np.moveaxis(out, out.ndim - ti_data_meta.ndim - 1, 0)
671+
out = np.moveaxis(out, batch_ndim - 1, 0)
677672

678673
return out
679674

@@ -695,8 +690,10 @@ def extract_slice(
695690
unsafe (bool): Whether to skip validity check of the masks.
696691
"""
697692
# Make sure that the user-arguments are valid if requested
693+
if col_mask is not None:
694+
is_vector = value.ndim == 1
698695
if not unsafe:
699-
if col_mask is not None and value.ndim == 1:
696+
if col_mask is not None and is_vector:
700697
gs.raise_exception("Cannot specify column mask for 1D tensor.")
701698
for i, mask in enumerate((row_mask, col_mask)):
702699
if mask is None or isinstance(mask, slice):
@@ -749,7 +746,7 @@ def extract_slice(
749746
out = out[row_mask, col_mask]
750747
else:
751748
if col_mask is not None:
752-
out = out[col_mask] if out.ndim == 1 else out[:, col_mask]
749+
out = out[col_mask] if is_vector else out[:, col_mask]
753750
if row_mask is not None:
754751
out = out[row_mask]
755752
except IndexError as e:
@@ -765,7 +762,7 @@ def extract_slice(
765762
if is_single_row:
766763
out = out[None]
767764
if is_single_col:
768-
out = out[None] if value.ndim == 1 else out[:, None]
765+
out = out[None] if is_vector else out[:, None]
769766

770767
return out
771768

@@ -778,7 +775,6 @@ def ti_to_torch(
778775
transpose=False,
779776
*,
780777
copy: bool | None = True,
781-
non_blocking: bool = False,
782778
unsafe=False,
783779
) -> torch.Tensor:
784780
"""Converts a GsTaichi field / ndarray instance to a PyTorch tensor.
@@ -791,12 +787,10 @@ def ti_to_torch(
791787
transpose (bool): Whether to move the first non-batch dimension to the front.
792788
copy (bool, optional): Whether to enforce returning a copy no matter what. None to avoid copy if possible
793789
without raising an exception if not.
794-
non_blocking (bool): Whether to skip GPU synchronization. It will be faster, but there will be no guarantee
795-
that the return buffer is up-to-date. Default to False.
796790
unsafe (bool, optional): Whether to skip validity check of the masks.
797791
"""
798792
# FIXME: Ideally one should detect if slicing would require a copy to avoid enforcing copy here
799-
tensor = ti_to_python(value, transpose, copy=copy, non_blocking=non_blocking, to_torch=True)
793+
tensor = ti_to_python(value, transpose, copy=copy, to_torch=True)
800794
if row_mask is None and col_mask is None:
801795
return tensor
802796

@@ -820,7 +814,6 @@ def ti_to_numpy(
820814
transpose=False,
821815
*,
822816
copy: bool | None = True,
823-
non_blocking: bool = False,
824817
unsafe=False,
825818
) -> np.ndarray:
826819
"""Converts a GsTaichi field / ndarray instance to a Numpy array.
@@ -833,11 +826,9 @@ def ti_to_numpy(
833826
transpose (bool, optional): Whether to move the first non-batch dimension to the front.
834827
copy (bool, optional): Whether to enforce returning a copy no matter what. None to avoid copy if possible
835828
without raising an exception if not.
836-
non_blocking (bool): Whether to skip GPU synchronization. It will be faster, but there will be no guarantee
837-
that the return buffer is up-to-date. Default to False.
838829
unsafe (bool, optional): Whether to skip validity check of the masks.
839830
"""
840-
tensor = ti_to_python(value, transpose, copy=copy, non_blocking=non_blocking, to_torch=False)
831+
tensor = ti_to_python(value, transpose, copy=copy, to_torch=False)
841832
if row_mask is None and col_mask is None:
842833
return tensor
843834

0 commit comments

Comments
 (0)