Skip to content

Commit cb8dbd3

Browse files
authored
[MISC] Add data accessor benchmark. (#2020)
1 parent 13cd721 commit cb8dbd3

File tree

1 file changed

+25
-8
lines changed

1 file changed

+25
-8
lines changed

tests/test_rigid_benchmarks.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def anymal_c(solver, n_envs, gjk):
391391
return {"compile_time": compile_time, "runtime_fps": runtime_fps, "realtime_factor": realtime_factor}
392392

393393

394-
def _batched_franka(solver, n_envs, gjk, is_collision_free):
394+
def _batched_franka(solver, n_envs, gjk, is_collision_free, accessors):
395395
scene = gs.Scene(
396396
rigid_options=gs.options.RigidOptions(
397397
**get_rigid_solver_options(
@@ -416,18 +416,28 @@ def _batched_franka(solver, n_envs, gjk, is_collision_free):
416416
scene.build(n_envs=n_envs)
417417
compile_time = time.time() - time_start
418418

419+
ctrl = torch.tensor([0, 0, 0, -1.0, 0, 1.0, 0, 0.02, 0.02], dtype=gs.tc_float, device=gs.device)
420+
if n_envs > 0:
421+
ctrl = torch.tile(ctrl, (n_envs, 1))
419422
if is_collision_free:
420-
franka.control_dofs_position(
421-
torch.tile(
422-
torch.tensor([0, 0, 0, -1.0, 0, 1.0, 0, 0.02, 0.02], dtype=gs.tc_float, device=gs.device), (n_envs, 1)
423-
),
424-
)
423+
franka.control_dofs_position(ctrl)
425424

426425
num_steps = 0
427426
is_recording = False
428427
time_start = time.time()
429428
while True:
430429
scene.step()
430+
if accessors:
431+
franka.get_ang()
432+
franka.get_vel()
433+
franka.get_dofs_position()
434+
franka.get_dofs_velocity()
435+
franka.get_links_pos()
436+
franka.get_links_quat()
437+
franka.get_links_vel()
438+
franka.get_contacts()
439+
franka.control_dofs_position(ctrl)
440+
431441
time_elapsed = time.time() - time_start
432442
if is_recording:
433443
num_steps += 1
@@ -444,12 +454,17 @@ def _batched_franka(solver, n_envs, gjk, is_collision_free):
444454

445455
@pytest.fixture
446456
def batched_franka(solver, n_envs, gjk):
447-
return _batched_franka(solver, n_envs, gjk, is_collision_free=False)
457+
return _batched_franka(solver, n_envs, gjk, is_collision_free=False, accessors=False)
448458

449459

450460
@pytest.fixture
451461
def batched_franka_free(solver, n_envs, gjk):
452-
return _batched_franka(solver, n_envs, gjk, is_collision_free=True)
462+
return _batched_franka(solver, n_envs, gjk, is_collision_free=True, accessors=False)
463+
464+
465+
@pytest.fixture
466+
def batched_franka_accessors(solver, n_envs, gjk):
467+
return _batched_franka(solver, n_envs, gjk, is_collision_free=True, accessors=True)
453468

454469

455470
def _duck_in_box(solver, n_envs, gjk, hard):
@@ -645,6 +660,8 @@ def box_pyramid(n_envs, n_cubes, enable_island, gjk):
645660
("anymal_c", gs.constraint_solver.Newton, None, 30000, gs.gpu),
646661
("anymal_c", None, None, 0, gs.gpu),
647662
("anymal_c", None, None, 0, gs.cpu),
663+
("batched_franka_accessors", None, None, 0, gs.cpu),
664+
("batched_franka_accessors", None, None, 30000, gs.gpu),
648665
("batched_franka_free", None, False, 30000, gs.gpu),
649666
("batched_franka_free", None, True, 30000, gs.gpu),
650667
("batched_franka", None, True, 30000, gs.gpu),

0 commit comments

Comments
 (0)