Skip to content

Commit bdad406

Browse files
authored
[MISC] Use zero-copy for all 'set_dof_*' except armature. (#2037)
* Keep benchmark PR comment in one line. * Add dofs acceleration in Rigid State. * Use zero-copy for all 'set_dof_*' except armature.
1 parent 83f4fed commit bdad406

File tree

5 files changed

+28
-23
lines changed

5 files changed

+28
-23
lines changed

.github/workflows/alarm.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -474,9 +474,8 @@ jobs:
474474
}
475475
476476
const title = (process.env.HAS_REGRESSIONS || '0') === '1'
477-
? 'Benchmark Regression Detected' : 'Abnormal Benchmark Result Detected';
478-
const comment = `:warning: **${title}**
479-
➡️ **[Report](${process.env.REPORT_URL})**`;
477+
? '🔴 Benchmark Regression Detected' : '⚠️ Abnormal Benchmark Result Detected';
478+
const comment = `**${title} ➡️ [Report](${process.env.REPORT_URL})**`;
480479
481480
await github.rest.issues.createComment({
482481
owner: context.repo.owner,

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,22 +1197,10 @@ def get_state(self, f):
11971197
if self.is_active:
11981198
state = RigidSolverState(self._scene)
11991199

1200-
# qpos: ti.types.ndarray(),
1201-
# vel: ti.types.ndarray(),
1202-
# links_pos: ti.types.ndarray(),
1203-
# links_quat: ti.types.ndarray(),
1204-
# i_pos_shift: ti.types.ndarray(),
1205-
# mass_shift: ti.types.ndarray(),
1206-
# friction_ratio: ti.types.ndarray(),
1207-
# links_state: array_class.LinksState,
1208-
# dofs_state: array_class.DofsState,
1209-
# geoms_state: array_class.GeomsState,
1210-
# rigid_global_info: array_class.RigidGlobalInfo,
1211-
# static_rigid_sim_config: ti.template(),
1212-
12131200
kernel_get_state(
12141201
qpos=state.qpos,
12151202
vel=state.dofs_vel,
1203+
acc=state.dofs_acc,
12161204
links_pos=state.links_pos,
12171205
links_quat=state.links_quat,
12181206
i_pos_shift=state.i_pos_shift,
@@ -1234,6 +1222,7 @@ def set_state(self, f, state, envs_idx=None):
12341222
kernel_set_state(
12351223
qpos=state.qpos,
12361224
dofs_vel=state.dofs_vel,
1225+
dofs_acc=state.dofs_acc,
12371226
links_pos=state.links_pos,
12381227
links_quat=state.links_quat,
12391228
i_pos_shift=state.i_pos_shift,
@@ -1749,6 +1738,14 @@ def set_sol_params(self, sol_params, geoms_idx=None, envs_idx=None, *, joints_id
17491738
)
17501739

17511740
def _set_dofs_info(self, tensor_list, dofs_idx, name, envs_idx=None, *, unsafe=False):
1741+
if gs.use_zerocopy and name in {"kp", "kv", "force_range", "stiffness", "damping", "frictionloss", "limit"}:
1742+
mask = indices_to_mask(*((envs_idx, dofs_idx) if self._options.batch_dofs_info else (dofs_idx,)))
1743+
data = ti_to_torch(getattr(self.dofs_info, name), transpose=True, copy=False)
1744+
num_values = len(tensor_list)
1745+
for j, mask_j in enumerate(((*mask, ..., j) for j in range(num_values)) if num_values > 1 else (mask,)):
1746+
data[mask_j] = torch.as_tensor(tensor_list[j], dtype=gs.tc_float, device=gs.device)
1747+
return
1748+
17521749
tensor_list = list(tensor_list)
17531750
for j, tensor in enumerate(tensor_list):
17541751
tensor_list[j], dofs_idx, envs_idx_ = self._sanitize_1D_io_variables(
@@ -1775,7 +1772,7 @@ def _set_dofs_info(self, tensor_list, dofs_idx, name, envs_idx=None, *, unsafe=F
17751772
elif name == "armature":
17761773
kernel_set_dofs_armature(tensor_list[0], dofs_idx, envs_idx_, self.dofs_info, self._static_rigid_sim_config)
17771774
qs_idx = torch.arange(self.n_qs, dtype=gs.tc_int, device=gs.device)
1778-
qpos_cur = self.get_qpos(envs_idx=envs_idx, qs_idx=qs_idx, unsafe=unsafe)
1775+
qpos_cur = self.get_qpos(qs_idx=qs_idx, envs_idx=envs_idx, unsafe=unsafe)
17791776
self._init_invweight_and_meaninertia(envs_idx=envs_idx, force_update=True, unsafe=unsafe)
17801777
self.set_qpos(qpos_cur, qs_idx=qs_idx, envs_idx=envs_idx, unsafe=unsafe)
17811778
elif name == "damping":
@@ -5842,11 +5839,11 @@ def kernel_update_vgeoms_render_T(
58425839
vgeoms_render_T[(i_g, i_b, *J)] = ti.cast(geom_T[J], ti.float32)
58435840

58445841

5845-
# FIXME: This kernel cannot use 'pure' because 'gs.Tensor' is currently not support by GsTaichi
5846-
@ti.kernel(fastcache=False)
5842+
@ti.kernel(fastcache=gs.use_fastcache)
58475843
def kernel_get_state(
58485844
qpos: ti.types.ndarray(),
58495845
vel: ti.types.ndarray(),
5846+
acc: ti.types.ndarray(),
58505847
links_pos: ti.types.ndarray(),
58515848
links_quat: ti.types.ndarray(),
58525849
i_pos_shift: ti.types.ndarray(),
@@ -5872,6 +5869,7 @@ def kernel_get_state(
58725869
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
58735870
for i_d, i_b in ti.ndrange(n_dofs, _B):
58745871
vel[i_b, i_d] = dofs_state.vel[i_d, i_b]
5872+
acc[i_b, i_d] = dofs_state.acc[i_d, i_b]
58755873

58765874
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
58775875
for i_l, i_b in ti.ndrange(n_links, _B):
@@ -5887,11 +5885,11 @@ def kernel_get_state(
58875885
friction_ratio[i_b, i_l] = geoms_state.friction_ratio[i_l, i_b]
58885886

58895887

5890-
# FIXME: This kernel cannot use 'pure' because 'gs.Tensor' is currently not support by GsTaichi
5891-
@ti.kernel(fastcache=False)
5888+
@ti.kernel(fastcache=gs.use_fastcache)
58925889
def kernel_set_state(
58935890
qpos: ti.types.ndarray(),
58945891
dofs_vel: ti.types.ndarray(),
5892+
dofs_acc: ti.types.ndarray(),
58955893
links_pos: ti.types.ndarray(),
58965894
links_quat: ti.types.ndarray(),
58975895
i_pos_shift: ti.types.ndarray(),
@@ -5917,6 +5915,7 @@ def kernel_set_state(
59175915
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
59185916
for i_d, i_b_ in ti.ndrange(n_dofs, envs_idx.shape[0]):
59195917
dofs_state.vel[i_d, envs_idx[i_b_]] = dofs_vel[envs_idx[i_b_], i_d]
5918+
dofs_state.acc[i_d, envs_idx[i_b_]] = dofs_acc[envs_idx[i_b_], i_d]
59205919
dofs_state.ctrl_force[i_d, envs_idx[i_b_]] = gs.ti_float(0.0)
59215920
dofs_state.ctrl_mode[i_d, envs_idx[i_b_]] = gs.CTRL_MODE.FORCE
59225921

genesis/engine/states/solvers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(self, scene):
5959
}
6060
self.qpos = gs.zeros((_B, scene.sim.rigid_solver.n_qs), **args)
6161
self.dofs_vel = gs.zeros((_B, scene.sim.rigid_solver.n_dofs), **args)
62+
self.dofs_acc = gs.zeros((_B, scene.sim.rigid_solver.n_dofs), **args)
6263
self.links_pos = gs.zeros((_B, scene.sim.rigid_solver.n_links, 3), **args)
6364
self.links_quat = gs.zeros((_B, scene.sim.rigid_solver.n_links, 4), **args)
6465
self.i_pos_shift = gs.zeros((_B, scene.sim.rigid_solver.n_links, 3), **args)
@@ -69,6 +70,7 @@ def serializable(self):
6970
self.scene = None
7071
self.qpos = self.qpos.detach()
7172
self.dofs_vel = self.dofs_vel.detach()
73+
self.dofs_acc = self.dofs_acc.detach()
7274
self.links_pos = self.links_pos.detach()
7375
self.links_quat = self.links_quat.detach()
7476
self.i_pos_shift = self.i_pos_shift.detach()

tests/test_rigid_benchmarks.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,9 @@ def _batched_franka(solver, n_envs, gjk, is_collision_free, accessors):
433433
reset_envs_idx = None
434434
qpos0 = ctrl
435435

436+
dofs_stiffness = franka.get_dofs_stiffness()
437+
dofs_damping = franka.get_dofs_damping()
438+
436439
num_steps = 0
437440
is_recording = False
438441
time_start = time.time()
@@ -441,6 +444,8 @@ def _batched_franka(solver, n_envs, gjk, is_collision_free, accessors):
441444
if accessors:
442445
franka.set_qpos(qpos0, envs_idx=reset_envs_idx, zero_velocity=False, skip_forward=True)
443446
franka.set_dofs_velocity(vel0, envs_idx=reset_envs_idx, skip_forward=True)
447+
franka.set_dofs_stiffness(dofs_stiffness)
448+
franka.set_dofs_damping(dofs_damping)
444449
franka.get_ang()
445450
franka.get_vel()
446451
franka.get_dofs_position()

tests/test_rigid_physics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2855,8 +2855,8 @@ def must_cast(value):
28552855
(gs_s.n_dofs, n_envs, gs_s.get_dofs_velocity, gs_s.set_dofs_velocity, gs_s.dofs_state.vel),
28562856
(gs_s.n_dofs, n_envs, gs_s.get_dofs_position, gs_s.set_dofs_position, gs_s.dofs_state.pos),
28572857
(gs_s.n_dofs, -1, gs_s.get_dofs_force_range, gs_s.set_dofs_force_range, gs_s.dofs_info.force_range),
2858-
(gs_s.n_dofs, -1, gs_s.get_dofs_limit, None, gs_s.dofs_info.limit),
2859-
(gs_s.n_dofs, -1, gs_s.get_dofs_stiffness, None, gs_s.dofs_info.stiffness),
2858+
(gs_s.n_dofs, -1, gs_s.get_dofs_limit, gs_s.set_dofs_limit, gs_s.dofs_info.limit),
2859+
(gs_s.n_dofs, -1, gs_s.get_dofs_stiffness, gs_s.set_dofs_stiffness, gs_s.dofs_info.stiffness),
28602860
(gs_s.n_dofs, -1, gs_s.get_dofs_invweight, None, gs_s.dofs_info.invweight),
28612861
(gs_s.n_dofs, -1, gs_s.get_dofs_armature, gs_s.set_dofs_armature, gs_s.dofs_info.armature),
28622862
(gs_s.n_dofs, -1, gs_s.get_dofs_damping, gs_s.set_dofs_damping, gs_s.dofs_info.damping),

0 commit comments

Comments
 (0)