Skip to content

Commit eea9caf

Browse files
authored
[BUG FIX] Fix zero-copy edge-case. (#2048)
* Fix zero-copy edge-case. * More robust data accessor unit test.
1 parent 9f06243 commit eea9caf

File tree

3 files changed

+60
-34
lines changed

3 files changed

+60
-34
lines changed

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,7 +1600,7 @@ def set_qpos(self, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False, unsa
16001600
data = ti_to_torch(self._rigid_global_info.qpos, transpose=True, copy=False)
16011601
data[mask] = torch.as_tensor(qpos, dtype=gs.tc_float, device=gs.device)
16021602
if mask and isinstance(mask[0], torch.Tensor):
1603-
envs_idx = mask[0]
1603+
envs_idx = mask[0].reshape((-1,))
16041604
else:
16051605
qpos, qs_idx, envs_idx = self._sanitize_1D_io_variables(
16061606
qpos, qs_idx, self.n_qs, envs_idx, idx_name="qs_idx", skip_allocation=True, unsafe=unsafe
@@ -1811,7 +1811,7 @@ def set_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None, *, skip_forw
18111811
mask = (0, *indices_to_mask(dofs_idx)) if self.n_envs == 0 else indices_to_mask(envs_idx, dofs_idx)
18121812
vel[mask] = 0.0 if velocity is None else torch.as_tensor(velocity, dtype=gs.tc_float, device=gs.device)
18131813
if mask and isinstance(mask[0], torch.Tensor):
1814-
envs_idx = mask[0]
1814+
envs_idx = mask[0].reshape((-1,))
18151815
elif not isinstance(envs_idx, torch.Tensor):
18161816
envs_idx = self._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe)
18171817
else:

tests/test_rigid_physics.py

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2875,19 +2875,35 @@ def test_data_accessor(n_envs, batched, tol):
28752875
# * Call 'Get' -> Call 'Set' with 'Get' output -> Call 'Get'
28762876
# * Compare first 'Get' output with last 'Get' output
28772877
# * Compare last 'Get' output with corresponding slice of non-masking 'Get' output
2878-
def get_all_supported_masks(i):
2878+
def get_all_supported_masks(i, max_length):
2879+
if max_length <= 0 or i > max_length - 1:
2880+
return (None,)
2881+
if i == max_length - 1:
2882+
return (
2883+
i,
2884+
[i],
2885+
slice(i, i + 1),
2886+
range(i, i + 1),
2887+
np.array([i], dtype=np.int32),
2888+
torch.tensor([i], dtype=torch.int64),
2889+
torch.tensor([i], dtype=gs.tc_int, device=gs.device),
2890+
)
28792891
return (
2880-
i,
2881-
[i],
2882-
slice(i, i + 1),
2883-
range(i, i + 1),
2884-
np.array([i], dtype=np.int32),
2885-
torch.tensor([i], dtype=torch.int64),
2886-
torch.tensor([i], dtype=gs.tc_int, device=gs.device),
2892+
[i, i + 1],
2893+
slice(i, i + 2),
2894+
range(i, i + 2),
2895+
np.array([i, i + 1], dtype=np.int32),
2896+
torch.tensor([i, i + 1], dtype=torch.int64),
2897+
torch.tensor([i, i + 1], dtype=gs.tc_int, device=gs.device),
28872898
)
28882899

2889-
def must_cast(value):
2890-
return not (isinstance(value, torch.Tensor) and value.dtype == gs.tc_int and value.device == gs.device)
2900+
def must_cast(value, dtype):
2901+
return not (
2902+
isinstance(value, torch.Tensor)
2903+
and value.is_contiguous()
2904+
and value.dtype == dtype
2905+
and value.device == gs.device
2906+
)
28912907

28922908
for arg1_max, arg2_max, getter_or_spec, setter, ti_data in (
28932909
# SOLVER
@@ -3010,66 +3026,73 @@ def must_cast(value):
30103026

30113027
# Check getter and setter for all possible combinations of row and column masking
30123028
for i in range(arg1_max) if arg1_max > 0 else (None,):
3013-
for arg1 in get_all_supported_masks(i) if arg1_max > 0 else (None,):
3029+
if i is not None:
3030+
mask_i = [i, i + 1] if i < arg1_max - 1 else [i]
3031+
for arg1 in get_all_supported_masks(i, arg1_max):
30143032
for j in range(max(arg2_max, 1)) if arg2_max >= 0 else (None,):
3015-
for arg2 in get_all_supported_masks(j) if arg2_max > 0 else (None,):
3033+
if j is not None:
3034+
mask_j = [j, j + 1] if j < arg2_max - 1 else [j]
3035+
for arg2 in get_all_supported_masks(j, arg2_max):
30163036
if arg1 is None and arg2 is not None:
3017-
unsafe = not must_cast(arg2)
3037+
unsafe = not must_cast(arg2, gs.tc_int)
30183038
if getter is not None:
30193039
data = deepcopy(getter(arg2, unsafe=unsafe))
30203040
else:
30213041
if is_tuple:
3022-
data = [torch.ones((1, *shape)) for shape in spec]
3042+
data = [torch.ones((len(mask_j), *shape)) for shape in spec]
30233043
else:
3024-
data = torch.ones((1, *spec))
3044+
data = torch.ones((len(mask_j), *spec))
30253045
if setter is not None:
3046+
unsafe &= not must_cast(data, gs.tc_float)
30263047
setter(data, arg2, unsafe=unsafe)
30273048
if n_envs:
30283049
if is_tuple:
3029-
data_ = [val[[j]] for val in datas]
3050+
data_ = [val[mask_j] for val in datas]
30303051
else:
3031-
data_ = datas[[j]]
3052+
data_ = datas[mask_j]
30323053
else:
30333054
data_ = datas
30343055
elif arg1 is not None and arg2 is None:
3035-
unsafe = not must_cast(arg1)
3056+
unsafe = not must_cast(arg1, gs.tc_int)
30363057
if getter is not None:
30373058
data = deepcopy(getter(arg1, unsafe=unsafe))
30383059
else:
30393060
if is_tuple:
3040-
data = [torch.ones((1, *shape)) for shape in spec]
3061+
data = [torch.ones((len(mask_i), *shape)) for shape in spec]
30413062
else:
3042-
data = torch.ones((1, *spec))
3063+
data = torch.ones((len(mask_i), *spec))
30433064
if setter is not None:
3065+
unsafe &= not must_cast(data, gs.tc_float)
30443066
if is_tuple:
30453067
setter(*data, arg1, unsafe=unsafe)
30463068
else:
30473069
setter(data, arg1, unsafe=unsafe)
30483070
if is_tuple:
3049-
data_ = [val[[i]] for val in datas]
3071+
data_ = [val[mask_i] for val in datas]
30503072
else:
3051-
data_ = datas[[i]]
3073+
data_ = datas[mask_i]
30523074
else:
3053-
unsafe = not any(map(must_cast, (arg1, arg2)))
3075+
unsafe = not any(must_cast(arg, gs.tc_int) for arg in (arg1, arg2))
30543076
if getter is not None:
30553077
data = deepcopy(getter(arg1, arg2, unsafe=unsafe))
30563078
else:
30573079
if is_tuple:
3058-
data = [torch.ones((1, 1, *shape)) for shape in spec]
3080+
data = [torch.ones((len(mask_j), len(mask_i), *shape)) for shape in spec]
30593081
else:
3060-
data = torch.ones((1, 1, *spec))
3082+
data = torch.ones((len(mask_j), len(mask_i), *spec))
30613083
if setter is not None:
3084+
unsafe &= not must_cast(data, gs.tc_float)
30623085
setter(data, arg1, arg2, unsafe=unsafe)
30633086
if is_tuple:
3064-
data_ = [val[[j], :][:, [i]] for val in datas]
3087+
data_ = [val[mask_j, :][:, mask_i] for val in datas]
30653088
else:
3066-
data_ = datas[[j], :][:, [i]]
3089+
data_ = datas[mask_j, :][:, mask_i]
30673090
# FIXME: Not sure why tolerance must be increased for tests to pass
30683091
assert_allclose(data_, data, tol=(5.0 * tol))
30693092

3070-
for dofs_idx in (*get_all_supported_masks(0), None):
3071-
for envs_idx in (*(get_all_supported_masks(0) if n_envs > 0 else ()), None):
3072-
unsafe = not any(map(must_cast, (dofs_idx, envs_idx)))
3093+
for dofs_idx in (*get_all_supported_masks(0, gs_s.n_dofs), None):
3094+
for envs_idx in (*(get_all_supported_masks(0, gs_s.n_dofs) if n_envs > 0 else ()), None):
3095+
unsafe = not any(must_cast(arg, gs.tc_int) for arg in (dofs_idx, envs_idx))
30733096
dofs_pos = gs_s.get_dofs_position(dofs_idx, envs_idx)
30743097
dofs_vel = gs_s.get_dofs_velocity(dofs_idx, envs_idx)
30753098
gs_s.control_dofs_position(dofs_pos, dofs_idx, envs_idx)

tests/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,11 @@ def assert_allclose(actual, desired, *, atol=None, rtol=None, tol=None, err_msg=
276276
# First, try to broadcast both matrices. Then it is does not work, squeeze them before trying again.
277277
try:
278278
args = np.broadcast_arrays(*args)
279-
except ValueError:
280-
args = np.broadcast_arrays(*map(np.squeeze, args))
279+
except ValueError as e:
280+
try:
281+
args = np.broadcast_arrays(*map(np.squeeze, args))
282+
except ValueError:
283+
raise e
281284

282285
np.testing.assert_allclose(*args, atol=atol, rtol=rtol, err_msg=err_msg)
283286

0 commit comments

Comments
 (0)