Skip to content

Commit b1a4935

Browse files
authored
[MISC] Improve robust many box unit test. (#2001)
* Improve robust many box unit test. * Improve unit test coverage. * More efficient 'kernel_update_all_verts'. * Disable fastcache on the CI because it causes crashes.
1 parent a869b46 commit b1a4935

File tree

5 files changed

+29
-21
lines changed

5 files changed

+29
-21
lines changed

.github/workflows/generic.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ jobs:
6868
PY_COLORS: "1"
6969
GS_CACHE_FILE_PATH: ".cache/genesis"
7070
GS_ENABLE_NDARRAY: ${{ matrix.GS_ENABLE_NDARRAY }}
71+
# FIXME: Disabling fastcache because it causes crashes for some reason...
72+
GS_ENABLE_FASTCACHE: "0"
7173
GS_TORCH_FORCE_CPU_DEVICE: ${{ startsWith(matrix.OS, 'macos-') && '1' || '0' }}
7274
TI_OFFLINE_CACHE: "1"
7375
TI_OFFLINE_CACHE_CLEANING_POLICY: "never"

genesis/engine/couplers/sap_coupler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,7 @@ def precompute(self, i_step):
601601

602602
if self.rigid_solver.is_active:
603603
kernel_update_all_verts(
604+
geoms_info=self.rigid_solver.geoms_info,
604605
geoms_state=self.rigid_solver.geoms_state,
605606
verts_info=self.rigid_solver.verts_info,
606607
free_verts_state=self.rigid_solver.free_verts_state,

genesis/engine/sensors/raycaster.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ def _update_bvh(cls, shared_metadata: RaycasterSharedMetadata):
334334
from genesis.engine.solvers.rigid.rigid_solver_decomp import kernel_update_all_verts
335335

336336
kernel_update_all_verts(
337+
geoms_info=shared_metadata.solver.geoms_info,
337338
geoms_state=shared_metadata.solver.geoms_state,
338339
verts_info=shared_metadata.solver.verts_info,
339340
free_verts_state=shared_metadata.solver.free_verts_state,

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5005,9 +5005,9 @@ def func_update_verts_for_geom(
50055005
fixed_verts_state: array_class.VertsState,
50065006
):
50075007
if not geoms_state.verts_updated[i_g, i_b]:
5008-
i_g_start = geoms_info.vert_start[i_g]
5009-
if verts_info.is_fixed[i_g_start]:
5010-
for i_v in range(i_g_start, geoms_info.vert_end[i_g]):
5008+
i_v_start = geoms_info.vert_start[i_g]
5009+
if verts_info.is_fixed[i_v_start]:
5010+
for i_v in range(i_v_start, geoms_info.vert_end[i_g]):
50115011
verts_state_idx = verts_info.verts_state_idx[i_v]
50125012
fixed_verts_state.pos[verts_state_idx] = gu.ti_transform_by_trans_quat(
50135013
verts_info.init_pos[i_v], geoms_state.pos[i_g, i_b], geoms_state.quat[i_g, i_b]
@@ -5016,7 +5016,7 @@ def func_update_verts_for_geom(
50165016
for j_b in range(_B):
50175017
geoms_state.verts_updated[i_g, j_b] = True
50185018
else:
5019-
for i_v in range(i_g_start, geoms_info.vert_end[i_g]):
5019+
for i_v in range(i_v_start, geoms_info.vert_end[i_g]):
50205020
verts_state_idx = verts_info.verts_state_idx[i_v]
50215021
free_verts_state.pos[verts_state_idx, i_b] = gu.ti_transform_by_trans_quat(
50225022
verts_info.init_pos[i_v], geoms_state.pos[i_g, i_b], geoms_state.quat[i_g, i_b]
@@ -5026,34 +5026,26 @@ def func_update_verts_for_geom(
50265026

50275027
@ti.func
50285028
def func_update_all_verts(
5029+
geoms_info: array_class.GeomsInfo,
50295030
geoms_state: array_class.GeomsState,
50305031
verts_info: array_class.VertsInfo,
50315032
free_verts_state: array_class.VertsState,
50325033
fixed_verts_state: array_class.VertsState,
50335034
):
5034-
n_verts = verts_info.geom_idx.shape[0]
5035-
_B = geoms_state.pos.shape[1]
5036-
for i_v, i_b in ti.ndrange(n_verts, _B):
5037-
i_g = verts_info.geom_idx[i_v]
5038-
verts_state_idx = verts_info.verts_state_idx[i_v]
5039-
if verts_info.is_fixed[i_v]:
5040-
fixed_verts_state.pos[verts_state_idx] = gu.ti_transform_by_trans_quat(
5041-
verts_info.init_pos[i_v], geoms_state.pos[i_g, i_b], geoms_state.quat[i_g, i_b]
5042-
)
5043-
else:
5044-
free_verts_state.pos[verts_state_idx, i_b] = gu.ti_transform_by_trans_quat(
5045-
verts_info.init_pos[i_v], geoms_state.pos[i_g, i_b], geoms_state.quat[i_g, i_b]
5046-
)
5035+
n_geoms, _B = geoms_state.pos.shape
5036+
for i_g, i_b in ti.ndrange(n_geoms, _B):
5037+
func_update_verts_for_geom(i_g, i_b, geoms_state, geoms_info, verts_info, free_verts_state, fixed_verts_state)
50475038

50485039

50495040
@ti.kernel(fastcache=gs.use_fastcache)
50505041
def kernel_update_all_verts(
5042+
geoms_info: array_class.GeomsInfo,
50515043
geoms_state: array_class.GeomsState,
50525044
verts_info: array_class.VertsInfo,
50535045
free_verts_state: array_class.VertsState,
50545046
fixed_verts_state: array_class.VertsState,
50555047
):
5056-
func_update_all_verts(geoms_state, verts_info, free_verts_state, fixed_verts_state)
5048+
func_update_all_verts(geoms_info, geoms_state, verts_info, free_verts_state, fixed_verts_state)
50575049

50585050

50595051
@ti.kernel
@@ -5115,7 +5107,6 @@ def func_hibernate__for_all_awake_islands_either_hiberanate_or_update_aabb_sort_
51155107
static_rigid_sim_config: ti.template(),
51165108
contact_island_state: array_class.ContactIslandState,
51175109
) -> None:
5118-
51195110
n_entities = entities_state.hibernated.shape[0]
51205111
_B = entities_state.hibernated.shape[1]
51215112

tests/test_rigid_physics.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,7 @@ def test_many_boxes_dynamics(box_box_detection, gjk_collision, dynamics, show_vi
877877
dt=0.01,
878878
),
879879
rigid_options=gs.options.RigidOptions(
880+
max_collision_pairs=1000,
880881
box_box_detection=box_box_detection,
881882
use_gjk_collision=gjk_collision,
882883
),
@@ -893,7 +894,7 @@ def test_many_boxes_dynamics(box_box_detection, gjk_collision, dynamics, show_vi
893894
i, j, k = int(n / 25), int(n / 5) % 5, n % 5
894895
scene.add_entity(
895896
gs.morphs.Box(
896-
pos=(i * 1.01, j * 1.01, k * 1.01 + 0.5),
897+
pos=(i * (1.0 - 1e-3), j * (1.0 - 1e-3), k * (1.0 - 1e-3) + 0.5),
897898
size=(1.0, 1.0, 1.0),
898899
),
899900
surface=gs.surfaces.Default(
@@ -920,7 +921,7 @@ def test_many_boxes_dynamics(box_box_detection, gjk_collision, dynamics, show_vi
920921
assert qpos[:2].norm() < 20.0
921922
assert qpos[2] < 5.0
922923
else:
923-
qpos0 = np.array((i * 1.01, j * 1.01, k * 1.01 + 0.5))
924+
qpos0 = np.array((i * (1.0 - 1e-3), j * (1.0 - 1e-3), k * (1.0 - 1e-3) + 0.5))
924925
assert_allclose(qpos[:3], qpos0, atol=0.05)
925926
assert_allclose(qpos[3:], 0, atol=0.03)
926927

@@ -1220,6 +1221,13 @@ def test_set_root_pose(batch_fixed_verts, relative, show_viewer, tol):
12201221
batch_fixed_verts=batch_fixed_verts,
12211222
),
12221223
)
1224+
sphere = scene.add_entity(
1225+
gs.morphs.Sphere(
1226+
radius=0.04,
1227+
batch_fixed_verts=False,
1228+
fixed=True,
1229+
),
1230+
)
12231231
cube = scene.add_entity(
12241232
gs.morphs.Box(
12251233
size=(0.04, 0.04, 0.04),
@@ -1244,6 +1252,7 @@ def test_set_root_pose(batch_fixed_verts, relative, show_viewer, tol):
12441252
scene.visualizer.update()
12451253
cube.set_pos(pos_delta[[0]] + (0.0, 0.0, 0.16), envs_idx=[0])
12461254
cube.set_pos(pos_delta[[1]] + (0.0, 0.0, 0.11), envs_idx=[1])
1255+
sphere.set_pos(np.tile(pos_delta[[0]], (2, 1)) + 1.0)
12471256
quat_delta = np.random.rand(2, 4)
12481257
with nullcontext() if batch_fixed_verts else pytest.raises(gs.GenesisException):
12491258
robot.set_quat(quat_delta)
@@ -1257,6 +1266,10 @@ def test_set_root_pose(batch_fixed_verts, relative, show_viewer, tol):
12571266
if show_viewer:
12581267
scene.visualizer.update()
12591268

1269+
sphere_aabb, sphere_base_aabb = sphere.get_AABB(), sphere.geoms[0].get_AABB()
1270+
assert_allclose(sphere_aabb.mean(dim=-2), pos_delta[0] + 1.0, tol=tol)
1271+
assert_allclose(sphere.get_AABB(), sphere.geoms[0].get_AABB(), tol=tol)
1272+
12601273
# Simulate for a while to check if the dynamic object is colliding with the static one
12611274
if batch_fixed_verts:
12621275
has_collided = torch.tensor([False, False], dtype=torch.bool, device=gs.device)

0 commit comments

Comments
 (0)