Skip to content

Commit 2a8c057

Browse files
committed
finalize differentiable pass
1 parent c413fde commit 2a8c057

File tree

12 files changed

+3138
-1361
lines changed

12 files changed

+3138
-1361
lines changed

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 217 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from genesis.utils import mjcf as mju
1818
from genesis.utils import terrain as tu
1919
from genesis.utils import urdf as uu
20-
from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch
20+
from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch, to_gs_tensor
21+
from genesis.engine.states.entities import RigidEntityState
2122

2223
from ..base_entity import Entity
2324
from .rigid_equality import RigidEquality
@@ -95,6 +96,23 @@ def __init__(
9596

9697
self._load_model()
9798

99+
self.init_tgt_vars()
100+
self.init_ckpt()
101+
102+
def init_tgt_keys(self):
103+
104+
self._tgt_keys = ["pos", "quat", "qpos", "dofs_velocity"]
105+
106+
def init_tgt_vars(self):
107+
108+
# temp variable to store targets for next step
109+
self._tgt = []
110+
self._tgt_buffer = []
111+
self.init_tgt_keys()
112+
113+
def init_ckpt(self):
114+
self._ckpt = dict()
115+
98116
def _load_model(self):
99117
self._links = gs.List()
100118
self._joints = gs.List()
@@ -1445,6 +1463,7 @@ def _kernel_forward_kinematics(
14451463
self._solver.entities_info,
14461464
self._solver._rigid_global_info,
14471465
self._solver._static_rigid_sim_config,
1466+
False,
14481467
)
14491468

14501469
ti.loop_config(serialize=self._solver._para_level < gs.PARA_LEVEL.PARTIAL)
@@ -1471,6 +1490,7 @@ def _kernel_forward_kinematics(
14711490
self._solver.entities_info,
14721491
self._solver._rigid_global_info,
14731492
self._solver._static_rigid_sim_config,
1493+
False,
14741494
)
14751495

14761496
# ------------------------------------------------------------------------------------
@@ -1605,6 +1625,128 @@ def plan_path(
16051625
# ------------------------------------------------------------------------------------
16061626
# ---------------------------------- control & io ------------------------------------
16071627
# ------------------------------------------------------------------------------------
1628+
def process_input(self, in_backward=False):
1629+
if in_backward:
1630+
# use negative index because buffer length might not be full
1631+
index = self._sim.cur_step_local - self._sim._steps_local
1632+
self._tgt = self._tgt_buffer[index].copy()
1633+
else:
1634+
self._tgt_buffer.append(self._tgt.copy())
1635+
1636+
# Apply targets sequentially
1637+
_tgt = self._tgt.copy()
1638+
for tgt in _tgt:
1639+
k = tgt["key"]
1640+
assert k in self._tgt_keys, f"Invalid target key: {k} not in {self._tgt_keys}"
1641+
1642+
# We do not need zero velocity here because if it was true, [set_dofs_velocity] from zero_velocity would
1643+
# be in [tgt]
1644+
zero_velocity = False
1645+
if k == "pos":
1646+
_pos = tgt["pos"]
1647+
_envs_idx = tgt["envs_idx"]
1648+
_relative = tgt["relative"]
1649+
_unsafe = tgt["unsafe"]
1650+
1651+
self.set_pos(_pos, envs_idx=_envs_idx, relative=_relative, zero_velocity=zero_velocity, unsafe=_unsafe)
1652+
elif k == "quat":
1653+
_quat = tgt["quat"]
1654+
_envs_idx = tgt["envs_idx"]
1655+
_relative = tgt["relative"]
1656+
_unsafe = tgt["unsafe"]
1657+
1658+
self.set_quat(
1659+
_quat, envs_idx=_envs_idx, relative=_relative, zero_velocity=zero_velocity, unsafe=_unsafe
1660+
)
1661+
elif k == "qpos":
1662+
_qpos = tgt["qpos"]
1663+
_qs_idx_local = tgt["qs_idx_local"]
1664+
_envs_idx = tgt["envs_idx"]
1665+
_unsafe = tgt["unsafe"]
1666+
1667+
self.set_qpos(
1668+
_qpos, qs_idx_local=_qs_idx_local, envs_idx=_envs_idx, zero_velocity=zero_velocity, unsafe=_unsafe
1669+
)
1670+
elif k == "dofs_velocity":
1671+
_velocity = tgt["velocity"]
1672+
_dofs_idx_local = tgt["dofs_idx_local"]
1673+
_envs_idx = tgt["envs_idx"]
1674+
_unsafe = tgt["unsafe"]
1675+
1676+
self.set_dofs_velocity(_velocity, dofs_idx_local=_dofs_idx_local, envs_idx=_envs_idx, unsafe=_unsafe)
1677+
1678+
self._tgt = []
1679+
1680+
def process_input_grad(self):
1681+
index = self._sim.cur_step_local - self._sim._steps_local
1682+
_tgt = self._tgt_buffer[index].copy()
1683+
1684+
for tgt in reversed(_tgt):
1685+
k = tgt["key"]
1686+
assert k in self._tgt_keys, f"Invalid target key: {k} not in {self._tgt_keys}"
1687+
if k == "pos":
1688+
_pos = tgt["pos"]
1689+
_envs_idx = tgt["envs_idx"]
1690+
_relative = tgt["relative"]
1691+
_unsafe = tgt["unsafe"]
1692+
1693+
if _pos is not None and _pos.requires_grad:
1694+
_pos._backward_from_ti(self.set_pos_grad, _envs_idx, _relative, _unsafe)
1695+
1696+
elif k == "quat":
1697+
_quat = tgt["quat"]
1698+
_envs_idx = tgt["envs_idx"]
1699+
_relative = tgt["relative"]
1700+
_unsafe = tgt["unsafe"]
1701+
1702+
if _quat is not None and _quat.requires_grad:
1703+
_quat._backward_from_ti(self.set_quat_grad, _envs_idx, _relative, _unsafe)
1704+
1705+
elif k == "qpos":
1706+
_qpos = tgt["qpos"]
1707+
_qs_idx_local = tgt["qs_idx_local"]
1708+
_envs_idx = tgt["envs_idx"]
1709+
_unsafe = tgt["unsafe"]
1710+
1711+
if _qpos is not None and _qpos.requires_grad:
1712+
# TODO: Not implemented yet
1713+
raise NotImplementedError("Backward pass for set_qpos_grad is not implemented yet.")
1714+
1715+
elif k == "dofs_velocity":
1716+
_velocity = tgt["velocity"]
1717+
_dofs_idx_local = tgt["dofs_idx_local"]
1718+
_envs_idx = tgt["envs_idx"]
1719+
_unsafe = tgt["unsafe"]
1720+
1721+
if _velocity is not None and _velocity.requires_grad:
1722+
_velocity._backward_from_ti(self.set_dofs_velocity_grad, _dofs_idx_local, _envs_idx, _unsafe)
1723+
1724+
def save_ckpt(self, ckpt_name):
1725+
if ckpt_name not in self._ckpt:
1726+
self._ckpt[ckpt_name] = {}
1727+
self._ckpt[ckpt_name]["_tgt_buffer"] = self._tgt_buffer.copy()
1728+
self._tgt_buffer.clear()
1729+
1730+
def load_ckpt(self, ckpt_name):
1731+
self._tgt_buffer = self._ckpt[ckpt_name]["_tgt_buffer"].copy()
1732+
1733+
def reset_grad(self):
1734+
self._tgt_buffer.clear()
1735+
1736+
@gs.assert_built
1737+
def get_state(self):
1738+
state = RigidEntityState(self, self._sim.cur_step_global)
1739+
1740+
solver_state = self._solver.get_state()
1741+
pos = solver_state.links_pos[:, self._base_links_idx].squeeze(-2)
1742+
quat = solver_state.links_quat[:, self._base_links_idx].squeeze(-2)
1743+
1744+
assert state._pos.shape == pos.shape
1745+
assert state._quat.shape == quat.shape
1746+
state._pos = pos
1747+
state._quat = quat
1748+
1749+
return state
16081750

16091751
def get_joint(self, name=None, uid=None):
16101752
"""
@@ -1949,6 +2091,18 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
19492091
envs_idx : None | array_like, optional
19502092
The indices of the environments. If None, all environments will be considered. Defaults to None.
19512093
"""
2094+
# Save in [tgt] for backward pass
2095+
self._tgt.append(
2096+
{
2097+
"key": "pos",
2098+
"pos": pos,
2099+
"envs_idx": envs_idx,
2100+
"relative": relative,
2101+
"zero_velocity": zero_velocity,
2102+
"unsafe": unsafe,
2103+
}
2104+
)
2105+
19522106
if not unsafe:
19532107
_pos = torch.as_tensor(pos, dtype=gs.tc_float, device=gs.device).contiguous()
19542108
if _pos is not pos:
@@ -1965,6 +2119,18 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
19652119
if zero_velocity:
19662120
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)
19672121

2122+
@gs.assert_built
2123+
def set_pos_grad(self, envs_idx, relative, unsafe, pos_grad):
2124+
tmp_pos_grad = pos_grad.unsqueeze(-2).clone()
2125+
self._solver.set_base_links_pos_grad(
2126+
self._base_links_idx,
2127+
envs_idx,
2128+
relative,
2129+
unsafe,
2130+
tmp_pos_grad,
2131+
)
2132+
pos_grad.data = tmp_pos_grad.squeeze(-2)
2133+
19682134
@gs.assert_built
19692135
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
19702136
"""
@@ -1983,6 +2149,17 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
19832149
envs_idx : None | array_like, optional
19842150
The indices of the environments. If None, all environments will be considered. Defaults to None.
19852151
"""
2152+
# Save in [tgt] for backward pass
2153+
self._tgt.append(
2154+
{
2155+
"key": "quat",
2156+
"quat": quat,
2157+
"envs_idx": envs_idx,
2158+
"relative": relative,
2159+
"zero_velocity": zero_velocity,
2160+
"unsafe": unsafe,
2161+
}
2162+
)
19862163
if not unsafe:
19872164
_quat = torch.as_tensor(quat, dtype=gs.tc_float, device=gs.device).contiguous()
19882165
if _quat is not quat:
@@ -1999,6 +2176,18 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
19992176
if zero_velocity:
20002177
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)
20012178

2179+
@gs.assert_built
2180+
def set_quat_grad(self, envs_idx, relative, unsafe, quat_grad):
2181+
tmp_quat_grad = quat_grad.unsqueeze(-2).clone()
2182+
self._solver.set_base_links_quat_grad(
2183+
self._base_links_idx,
2184+
envs_idx,
2185+
relative,
2186+
unsafe,
2187+
tmp_quat_grad,
2188+
)
2189+
quat_grad.data = tmp_quat_grad.squeeze(-2)
2190+
20022191
@gs.assert_built
20032192
def get_verts(self):
20042193
"""
@@ -2091,6 +2280,18 @@ def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True
20912280
zero_velocity : bool, optional
20922281
Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose.
20932282
"""
2283+
# Save in [tgt] for backward pass
2284+
self._tgt.append(
2285+
{
2286+
"key": "qpos",
2287+
"qpos": qpos,
2288+
"qs_idx_local": qs_idx_local,
2289+
"envs_idx": envs_idx,
2290+
"zero_velocity": zero_velocity,
2291+
"unsafe": unsafe,
2292+
}
2293+
)
2294+
20942295
qs_idx = self._get_idx(qs_idx_local, self.n_qs, self._q_start, unsafe=True)
20952296
self._solver.set_qpos(qpos, qs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity)
20962297
if zero_velocity:
@@ -2186,9 +2387,24 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *
21862387
envs_idx : None | array_like, optional
21872388
The indices of the environments. If None, all environments will be considered. Defaults to None.
21882389
"""
2390+
# Save in [tgt] for backward pass
2391+
self._tgt.append(
2392+
{
2393+
"key": "dofs_velocity",
2394+
"velocity": velocity,
2395+
"dofs_idx_local": dofs_idx_local,
2396+
"envs_idx": envs_idx,
2397+
"unsafe": unsafe,
2398+
}
2399+
)
21892400
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
21902401
self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=False, unsafe=unsafe)
21912402

2403+
@gs.assert_built
2404+
def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, unsafe, velocity_grad):
2405+
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
2406+
self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, unsafe, velocity_grad)
2407+
21922408
@gs.assert_built
21932409
def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
21942410
"""

genesis/engine/simulator.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -268,18 +268,13 @@ def f_global_to_s_global(self, f_global):
268268
# ------------------------------------------------------------------------------------
269269

270270
def step(self, in_backward=False):
271-
if self._rigid_only: # "Only Advance!" --Thomas Wade :P
272-
for _ in range(self._substeps):
273-
self.rigid_solver.substep()
274-
self._cur_substep_global += 1
275-
else:
276-
self.process_input(in_backward=in_backward)
277-
for _ in range(self._substeps):
278-
self.substep(self.cur_substep_local)
271+
self.process_input(in_backward=in_backward)
272+
for _ in range(self._substeps):
273+
self.substep(self.cur_substep_local)
279274

280-
self._cur_substep_global += 1
281-
if self.cur_substep_local == 0 and not in_backward:
282-
self.save_ckpt()
275+
self._cur_substep_global += 1
276+
if self.cur_substep_local == 0 and not in_backward:
277+
self.save_ckpt()
283278

284279
if self.rigid_solver.is_active:
285280
self.rigid_solver.clear_external_force()

genesis/engine/solvers/rigid/constraint_noslip.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ def kernel_build_efc_AR_b(
3636
rigid_solver.func_solve_mass_batched(
3737
constraint_state.Mgrad,
3838
constraint_state.Mgrad,
39+
None,
3940
i_b,
4041
entities_info=entities_info,
4142
rigid_global_info=rigid_global_info,
4243
static_rigid_sim_config=static_rigid_sim_config,
44+
is_backward=False,
4345
)
4446

4547
# AR[r, c] = J[c, :] * tmp
@@ -192,10 +194,12 @@ def kernel_dual_finish(
192194
rigid_solver.func_solve_mass_batched(
193195
vec=constraint_state.qfrc_constraint,
194196
out=constraint_state.qacc,
197+
out_bw=None,
195198
i_b=i_b,
196199
entities_info=entities_info,
197200
rigid_global_info=rigid_global_info,
198201
static_rigid_sim_config=static_rigid_sim_config,
202+
is_backward=False,
199203
)
200204

201205
for i_d in range(n_dofs):
@@ -284,10 +288,12 @@ def compute_A_diag(
284288
rigid_solver.func_solve_mass_batched(
285289
constraint_state.Mgrad,
286290
constraint_state.Mgrad,
291+
None,
287292
i_b,
288293
entities_info=entities_info,
289294
rigid_global_info=rigid_global_info,
290295
static_rigid_sim_config=static_rigid_sim_config,
296+
is_backward=False,
291297
)
292298

293299
# Ai = Ji * tmp

genesis/engine/solvers/rigid/constraint_solver_decomp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1987,10 +1987,12 @@ def func_update_gradient(
19871987
rigid_solver.func_solve_mass_batched(
19881988
constraint_state.grad,
19891989
constraint_state.Mgrad,
1990+
None,
19901991
i_b,
19911992
entities_info=entities_info,
19921993
rigid_global_info=rigid_global_info,
19931994
static_rigid_sim_config=static_rigid_sim_config,
1995+
is_backward=False,
19941996
)
19951997

19961998
elif ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):

genesis/engine/solvers/rigid/constraint_solver_decomp_island.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,7 @@ def _func_update_gradient(self, island, i_b):
10081008
i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity
10091009
i_e = self.contact_island.entity_id[i_e_, i_b]
10101010
self._solver._mass_mat_mask[i_e_, i_b] = 1
1011-
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, i_b)
1011+
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, None, i_b)
10121012
for i_e in range(self._solver.n_entities):
10131013
self._solver._mass_mat_mask[i_e, i_b] = 1
10141014
elif ti.static(self._solver_type == gs.constraint_solver.Newton):

0 commit comments

Comments
 (0)