Skip to content

Commit 2b64f0e

Browse files
authored
[MISC] Avoid forcing copy in getters. (#2030)
* Remove fragile link data caching mechanism now that zero-copy is available. * Avoid forcing copy in getters. * Disable 'ti.u1' on Vulkan because it is broken.
1 parent cc2b954 commit 2b64f0e

File tree

9 files changed

+37
-76
lines changed

9 files changed

+37
-76
lines changed

genesis/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ def init(
171171
tc_int = torch.int32
172172

173173
# Bool
174-
# Note that `ti.u1` is broken on Apple Metal and output garbage.
174+
# Note that `ti.u1` is broken on Apple Metal and Vulkan.
175175
global ti_bool, np_bool, tc_bool
176-
if backend == gs_backend.metal:
176+
if backend in (gs_backend.metal, gs_backend.vulkan):
177177
ti_bool = ti.i32
178178
np_bool = np.int32
179179
tc_bool = torch.int32

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 11 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,7 @@
1414
from genesis.engine.states.solvers import RigidSolverState
1515
from genesis.options.solvers import RigidOptions
1616
from genesis.utils import linalg as lu
17-
from genesis.utils.misc import (
18-
ALLOCATE_TENSOR_WARNING,
19-
DeprecationError,
20-
ti_to_torch,
21-
ti_to_numpy,
22-
ti_to_python,
23-
indices_to_mask,
24-
_get_ti_metadata,
25-
)
17+
from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch, ti_to_numpy, indices_to_mask
2618
from genesis.utils.sdf_decomp import SDF
2719

2820
from ..base_solver import Solver
@@ -131,7 +123,6 @@ def __init__(self, scene: "Scene", sim: "Simulator", options: RigidOptions) -> N
131123
self._options = options
132124

133125
self._cur_step = -1
134-
self._links_state_cache = {}
135126

136127
self.qpos: ti.Template | ti.types.NDArray | None = None
137128

@@ -819,38 +810,10 @@ def _init_constraint_solver(self):
819810
else:
820811
self.constraint_solver = ConstraintSolver(self)
821812

822-
def _get_links_data(
823-
self,
824-
field_name: str,
825-
row_mask: slice | int | range | list | torch.Tensor | np.ndarray | None = None,
826-
col_mask: slice | int | range | list | torch.Tensor | np.ndarray | None = None,
827-
keepdim=True,
828-
*,
829-
to_torch=True,
830-
):
831-
links_state_py = self._links_state_cache.setdefault((to_torch,), {})
832-
833-
field = getattr(self.links_state, field_name)
834-
tensor = links_state_py.get(field_name)
835-
if tensor is None:
836-
tensor = links_state_py[field_name] = ti_to_python(field, transpose=True, to_torch=to_torch)
837-
838-
ti_data_meta = _get_ti_metadata(field)
839-
if len(ti_data_meta.shape) < 2:
840-
if row_mask is not None and col_mask is not None:
841-
gs.raise_exception("Cannot specify both row and colum masks for tensor with 1D batch.")
842-
mask = indices_to_mask(row_mask if col_mask is None else col_mask, keepdim=keepdim, to_torch=to_torch)
843-
else:
844-
mask = indices_to_mask(row_mask, col_mask, keepdim=keepdim, to_torch=to_torch)
845-
846-
return tensor[mask]
847-
848813
def substep(self):
849814
# from genesis.utils.tools import create_timer
850815
from genesis.engine.couplers import SAPCoupler
851816

852-
self._links_state_cache.clear()
853-
854817
kernel_step_1(
855818
links_state=self.links_state,
856819
links_info=self.links_info,
@@ -1301,7 +1264,6 @@ def set_state(self, f, state, envs_idx=None):
13011264
self.collider.clear(envs_idx)
13021265
if self.constraint_solver is not None:
13031266
self.constraint_solver.reset(envs_idx)
1304-
self._links_state_cache.clear()
13051267
self._cur_step = -1
13061268

13071269
def process_input(self, in_backward=False):
@@ -1541,7 +1503,6 @@ def set_base_links_pos(self, pos, links_idx=None, envs_idx=None, *, relative=Fal
15411503
static_rigid_sim_config=self._static_rigid_sim_config,
15421504
)
15431505

1544-
self._links_state_cache.clear()
15451506
kernel_forward_kinematics_links_geoms(
15461507
envs_idx,
15471508
links_state=self.links_state,
@@ -1590,7 +1551,6 @@ def set_base_links_quat(self, quat, links_idx=None, envs_idx=None, *, relative=F
15901551
static_rigid_sim_config=self._static_rigid_sim_config,
15911552
)
15921553

1593-
self._links_state_cache.clear()
15941554
kernel_forward_kinematics_links_geoms(
15951555
envs_idx,
15961556
links_state=self.links_state,
@@ -1668,7 +1628,6 @@ def set_qpos(self, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False, unsa
16681628
qpos = qpos.unsqueeze(0)
16691629
kernel_set_qpos(qpos, qs_idx, envs_idx, self._rigid_global_info, self._static_rigid_sim_config)
16701630

1671-
self._links_state_cache.clear()
16721631
self.collider.reset(envs_idx, cache_only=True)
16731632
if not isinstance(envs_idx, torch.Tensor):
16741633
envs_idx = self._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe)
@@ -1877,7 +1836,6 @@ def set_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None, *, skip_forw
18771836
velocity = velocity.unsqueeze(0)
18781837
kernel_set_dofs_velocity(velocity, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config)
18791838

1880-
self._links_state_cache.clear()
18811839
if not skip_forward:
18821840
kernel_forward_velocity(
18831841
envs_idx,
@@ -1908,7 +1866,6 @@ def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, unsafe=Fa
19081866
self._static_rigid_sim_config,
19091867
)
19101868

1911-
self._links_state_cache.clear()
19121869
self.collider.reset(envs_idx, cache_only=True)
19131870
self.collider.clear(envs_idx)
19141871
if self.constraint_solver is not None:
@@ -2055,20 +2012,20 @@ def get_links_pos(
20552012
):
20562013
ref = self._convert_ref_to_idx(ref)
20572014
if ref == 0:
2058-
tensor = self._get_links_data("root_COM", envs_idx, links_idx, to_torch=to_torch)
2015+
tensor = ti_to_torch(self.links_state.root_COM, envs_idx, links_idx, transpose=True)
20592016
elif ref == 1:
2060-
i_pos = self._get_links_data("i_pos", envs_idx, links_idx, to_torch=to_torch)
2061-
root_COM = self._get_links_data("root_COM", envs_idx, links_idx, to_torch=to_torch)
2017+
i_pos = ti_to_torch(self.links_state.i_pos, envs_idx, links_idx, transpose=True)
2018+
root_COM = ti_to_torch(self.links_state.root_COM, envs_idx, links_idx, transpose=True)
20622019
tensor = i_pos + root_COM
20632020
elif ref == 2:
2064-
tensor = self._get_links_data("pos", envs_idx, links_idx, to_torch=to_torch)
2021+
tensor = ti_to_torch(self.links_state.pos, envs_idx, links_idx, transpose=True)
20652022
else:
20662023
gs.raise_exception("'ref' must be either 'link_origin', 'link_com', or 'root_com'.")
20672024

20682025
return tensor[0] if self.n_envs == 0 else tensor
20692026

20702027
def get_links_quat(self, links_idx=None, envs_idx=None, *, to_torch=True, unsafe=False):
2071-
tensor = self._get_links_data("quat", envs_idx, links_idx, to_torch=to_torch)
2028+
tensor = ti_to_torch(self.links_state.quat, envs_idx, links_idx, transpose=True)
20722029
return tensor[0] if self.n_envs == 0 else tensor
20732030

20742031
def get_links_vel(
@@ -2103,7 +2060,7 @@ def get_links_vel(
21032060
return _tensor
21042061

21052062
def get_links_ang(self, links_idx=None, envs_idx=None, *, to_torch=True, unsafe=False):
2106-
tensor = self._get_links_data("cd_ang", envs_idx, links_idx, to_torch=to_torch)
2063+
tensor = ti_to_torch(self.links_state.cd_ang, envs_idx, links_idx, transpose=True)
21072064
return tensor[0] if self.n_envs == 0 else tensor
21082065

21092066
def get_links_acc(self, links_idx=None, envs_idx=None, *, unsafe=False):
@@ -2121,7 +2078,7 @@ def get_links_acc(self, links_idx=None, envs_idx=None, *, unsafe=False):
21212078
return _tensor
21222079

21232080
def get_links_acc_ang(self, links_idx=None, envs_idx=None, *, to_torch=True, unsafe=False):
2124-
tensor = self._get_links_data("cacc_ang", envs_idx, links_idx, to_torch=to_torch)
2081+
tensor = ti_to_torch(self.links_state.cacc_ang, envs_idx, links_idx, transpose=True)
21252082
return tensor[0] if self.n_envs == 0 else tensor
21262083

21272084
def get_links_root_COM(self, links_idx=None, envs_idx=None, *, to_torch=True, unsafe=False):
@@ -2131,15 +2088,15 @@ def get_links_root_COM(self, links_idx=None, envs_idx=None, *, to_torch=True, un
21312088
This corresponds to the global COM of each entity, assuming a single-rooted structure - that is, as long as no
21322089
two successive links are connected by a free-floating joint (ie a joint that allows all 6 degrees of freedom).
21332090
"""
2134-
tensor = self._get_links_data("root_COM", envs_idx, links_idx, to_torch=to_torch)
2091+
tensor = ti_to_torch(self.links_state.root_COM, envs_idx, links_idx, transpose=True)
21352092
return tensor[0] if self.n_envs == 0 else tensor
21362093

21372094
def get_links_mass_shift(self, links_idx=None, envs_idx=None, *, to_torch=True, unsafe=False):
2138-
tensor = self._get_links_data("mass_shift", envs_idx, links_idx, to_torch=to_torch)
2095+
tensor = ti_to_torch(self.links_state.mass_shift, envs_idx, links_idx, transpose=True)
21392096
return tensor[0] if self.n_envs == 0 else tensor
21402097

21412098
def get_links_COM_shift(self, links_idx=None, envs_idx=None, *, to_torch=True, unsafe=False):
2142-
tensor = self._get_links_data("i_pos_shift", envs_idx, links_idx, to_torch=to_torch)
2099+
tensor = ti_to_torch(self.links_state.i_pos_shift, envs_idx, links_idx, transpose=True)
21432100
return tensor[0] if self.n_envs == 0 else tensor
21442101

21452102
def get_links_inertial_mass(self, links_idx=None, envs_idx=None, *, unsafe=False):

genesis/options/renderers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class RayTracer(RendererOptions):
7272
rr_depth: int = 0
7373
rr_threshold: float = 0.95
7474

75-
# environment texure
75+
# environment texture
7676
env_surface: Optional[Surface] = None
7777
env_radius: float = 1000.0
7878
env_pos: tuple = (0.0, 0.0, 0.0)

genesis/utils/misc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def _get_ti_metadata(value: ti.Field | ti.Ndarray) -> FieldMetadata:
571571
def ti_to_python(
572572
value: ti.Field | ti.Ndarray,
573573
transpose: bool = False,
574-
copy: bool | None = True,
574+
copy: bool | None = None,
575575
to_torch: bool = True,
576576
) -> torch.Tensor | np.ndarray:
577577
"""Converts a GsTaichi field / ndarray instance to a PyTorch tensor / Numpy array.
@@ -761,7 +761,7 @@ def ti_to_torch(
761761
keepdim=True,
762762
transpose=False,
763763
*,
764-
copy: bool | None = True,
764+
copy: bool | None = None,
765765
) -> torch.Tensor:
766766
"""Converts a GsTaichi field / ndarray instance to a PyTorch tensor.
767767
@@ -800,7 +800,7 @@ def ti_to_numpy(
800800
keepdim=True,
801801
transpose=False,
802802
*,
803-
copy: bool | None = True,
803+
copy: bool | None = None,
804804
) -> np.ndarray:
805805
"""Converts a GsTaichi field / ndarray instance to a Numpy array.
806806

genesis/utils/path_planning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def update_object(self, ee_link_idx, obj_link_idx, _pos, _quat, envs_idx):
8383
# ------------------------------------------------------------------------------------
8484

8585
def _sanitize_qposs(self, qpos_goal, qpos_start, envs_idx):
86-
qpos_cur = self._entity.get_qpos(envs_idx=envs_idx)
86+
qpos_cur = self._entity.get_qpos(envs_idx=envs_idx).clone()
8787

8888
qpos_goal, _, _ = self._solver._sanitize_1D_io_variables(
8989
qpos_goal, None, self._entity.n_qs, envs_idx, idx_name="qpos_idx", skip_allocation=True

tests/test_pbd.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ def test_cloth_attach_rigid_link(show_viewer):
187187
vel = np.array([[-0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]], dtype=np.float32)
188188
box.set_dofs_velocity(vel, dofs_idx_local=[0, 1, 2])
189189

190-
cloth_pos0 = cloth.get_particles_pos()[:, particles_idx]
191-
link_pos0 = scene.rigid_solver.links[box_link_idx].get_pos()
190+
cloth_pos0 = cloth.get_particles_pos()[:, particles_idx].clone()
191+
link_pos0 = scene.rigid_solver.links[box_link_idx].get_pos().clone()
192192

193193
for _ in range(25):
194194
scene.step()
@@ -199,8 +199,8 @@ def test_cloth_attach_rigid_link(show_viewer):
199199
scene.step()
200200

201201
# Check that the attached particles followed the link displacement per env
202-
cloth_pos1 = cloth.get_particles_pos()[:, particles_idx]
203-
link_pos1 = scene.rigid_solver.links[box_link_idx].get_pos()
202+
cloth_pos1 = cloth.get_particles_pos()[:, particles_idx].clone()
203+
link_pos1 = scene.rigid_solver.links[box_link_idx].get_pos().clone()
204204

205205
cloth_disp = cloth_pos1 - cloth_pos0
206206
link_disp = link_pos1 - link_pos0
@@ -216,8 +216,8 @@ def test_cloth_attach_rigid_link(show_viewer):
216216
scene.step()
217217

218218
# Make sure that the cloth is laying on the ground without moving
219-
cloth_pos2 = cloth.get_particles_pos()[:, particles_idx]
220-
link_pos2 = scene.rigid_solver.links[box_link_idx].get_pos()
219+
cloth_pos2 = cloth.get_particles_pos()[:, particles_idx].clone()
220+
link_pos2 = scene.rigid_solver.links[box_link_idx].get_pos().clone()
221221
cloth_disp = cloth_pos2 - cloth_pos1
222222
link_disp = link_pos2 - link_pos1
223223
link_disp = link_disp.unsqueeze(1)

tests/test_rigid_physics.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import tempfile
66
import xml.etree.ElementTree as ET
77
from contextlib import nullcontext
8+
from copy import deepcopy
89
from typing import cast
910
from pathlib import Path
1011

@@ -979,17 +980,17 @@ def test_robot_scale_and_dofs_armature(xml_path, tol):
979980
# It is also a good opportunity to check that it updates 'invweight' and meaninertia accordingly.
980981
attr_orig = {}
981982
for scale, robot in zip(ROBOT_SCALES, scene.entities):
982-
links_invweight = robot.get_links_invweight()
983-
dofs_invweight = robot.get_dofs_invweight()
983+
links_invweight = robot.get_links_invweight().clone()
984+
dofs_invweight = robot.get_dofs_invweight().clone()
984985
robot.set_dofs_armature(torch.ones((robot.n_dofs,), dtype=gs.tc_float, device=gs.device))
985986
assert torch.all(robot.get_dofs_invweight() < 1.0)
986987
with pytest.raises(AssertionError):
987988
assert_allclose(robot.get_dofs_invweight(), dofs_invweight, tol=tol)
988989
with pytest.raises(AssertionError):
989990
assert_allclose(robot.get_links_invweight(), links_invweight, tol=tol)
990991
robot.set_dofs_armature(torch.zeros((robot.n_dofs,), dtype=gs.tc_float, device=gs.device))
991-
links_invweight = robot.get_links_invweight()
992-
dofs_invweight = robot.get_dofs_invweight()
992+
links_invweight = robot.get_links_invweight().clone()
993+
dofs_invweight = robot.get_dofs_invweight().clone()
993994
qpos = np.random.rand(robot.n_dofs)
994995
robot.set_dofs_position(qpos)
995996
robot.set_dofs_armature(torch.zeros((robot.n_dofs,), dtype=gs.tc_float, device=gs.device))
@@ -2911,7 +2912,7 @@ def must_cast(value):
29112912

29122913
# Check getter and setter without row or column masking
29132914
if getter is not None:
2914-
datas = getter()
2915+
datas = deepcopy(getter())
29152916
is_tuple = isinstance(datas, (tuple, list))
29162917
if arg1_max > 0:
29172918
assert_allclose(getter(range(arg1_max)), datas, tol=tol)
@@ -2964,7 +2965,7 @@ def must_cast(value):
29642965
if arg1 is None and arg2 is not None:
29652966
unsafe = not must_cast(arg2)
29662967
if getter is not None:
2967-
data = getter(arg2, unsafe=unsafe)
2968+
data = deepcopy(getter(arg2, unsafe=unsafe))
29682969
else:
29692970
if is_tuple:
29702971
data = [torch.ones((1, *shape)) for shape in spec]
@@ -2982,7 +2983,7 @@ def must_cast(value):
29822983
elif arg1 is not None and arg2 is None:
29832984
unsafe = not must_cast(arg1)
29842985
if getter is not None:
2985-
data = getter(arg1, unsafe=unsafe)
2986+
data = deepcopy(getter(arg1, unsafe=unsafe))
29862987
else:
29872988
if is_tuple:
29882989
data = [torch.ones((1, *shape)) for shape in spec]
@@ -3000,7 +3001,7 @@ def must_cast(value):
30003001
else:
30013002
unsafe = not any(map(must_cast, (arg1, arg2)))
30023003
if getter is not None:
3003-
data = getter(arg1, arg2, unsafe=unsafe)
3004+
data = deepcopy(getter(arg1, arg2, unsafe=unsafe))
30043005
else:
30053006
if is_tuple:
30063007
data = [torch.ones((1, 1, *shape)) for shape in spec]

tests/test_sensors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def test_raycaster_hits(show_viewer, n_envs):
419419
entity.set_pos(pos)
420420
if show_viewer:
421421
scene.visualizer.update(force=True)
422-
grid_sensor_pos = grid_sensor.get_pos()
422+
grid_sensor_pos = grid_sensor.get_pos().clone()
423423
for _ in range(100):
424424
scene.step()
425425
grid_sensor.set_pos(grid_sensor_pos)

tests/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from functools import cache
1212
from itertools import chain
1313
from pathlib import Path
14+
from types import GeneratorType
1415
from typing import Literal, Sequence
1516

1617
import cpuinfo
@@ -261,6 +262,8 @@ def assert_allclose(actual, desired, *, atol=None, rtol=None, tol=None, err_msg=
261262
# Convert input arguments as numpy arrays
262263
args = [actual, desired]
263264
for i, arg in enumerate(args):
265+
if isinstance(arg, (GeneratorType, map)):
266+
arg = tuple(arg)
264267
if isinstance(arg, (tuple, list)):
265268
arg = np.stack([tensor_to_array(val) for val in arg], axis=0)
266269
args[i] = tensor_to_array(arg)

0 commit comments

Comments
 (0)