-
Notifications
You must be signed in to change notification settings - Fork 2.5k
[FEATURE] Differentiable forward dynamics for rigid body sim. #1808
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
2a8c057
48b00ce
6e9f65f
dadece9
66d71f6
89fa76c
bc38956
bdac4fe
61a9902
82d2251
ca2487f
4eca23a
436e986
8b6a03b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,8 @@ | |
| from genesis.utils import mjcf as mju | ||
| from genesis.utils import terrain as tu | ||
| from genesis.utils import urdf as uu | ||
| from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch | ||
| from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch, to_gs_tensor | ||
| from genesis.engine.states.entities import RigidEntityState | ||
|
|
||
| from ..base_entity import Entity | ||
| from .rigid_equality import RigidEquality | ||
|
|
@@ -95,6 +96,23 @@ def __init__( | |
|
|
||
| self._load_model() | ||
|
|
||
| self.init_tgt_vars() | ||
| self.init_ckpt() | ||
|
|
||
def init_tgt_keys(self):
    """Set ``self._tgt_keys`` to the target kinds accepted when replaying queued targets."""
    # One key per recording setter: set_pos, set_quat, set_qpos, set_dofs_velocity.
    self._tgt_keys = [
        "pos",
        "quat",
        "qpos",
        "dofs_velocity",
    ]
|
|
||
def init_tgt_vars(self):
    """Create the empty target containers, then register the valid target keys.

    ``_tgt`` holds the targets queued for the next step; ``_tgt_buffer`` starts
    empty and is filled elsewhere with per-step snapshots of ``_tgt``.
    """
    self._tgt, self._tgt_buffer = [], []
    self.init_tgt_keys()
|
|
||
def init_ckpt(self):
    """Start with an empty checkpoint store (a dict assigned to ``self._ckpt``)."""
    self._ckpt = {}
|
|
||
| def _load_model(self): | ||
| self._links = gs.List() | ||
| self._joints = gs.List() | ||
|
|
@@ -1460,6 +1478,7 @@ def _kernel_forward_kinematics( | |
| entities_info, | ||
| rigid_global_info, | ||
| static_rigid_sim_config, | ||
| False, | ||
|
||
| ) | ||
|
|
||
| ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) | ||
|
|
@@ -1486,6 +1505,7 @@ def _kernel_forward_kinematics( | |
| entities_info, | ||
| rigid_global_info, | ||
| static_rigid_sim_config, | ||
| False, | ||
|
||
| ) | ||
|
|
||
| # ------------------------------------------------------------------------------------ | ||
|
|
@@ -1620,6 +1640,128 @@ def plan_path( | |
| # ------------------------------------------------------------------------------------ | ||
| # ---------------------------------- control & io ------------------------------------ | ||
| # ------------------------------------------------------------------------------------ | ||
def process_input(self, in_backward=False):
    """Replay the queued targets through the matching setters, then empty the queue.

    Parameters
    ----------
    in_backward : bool, optional
        If False (default), a copy of the current queue is appended to
        ``self._tgt_buffer`` before replay. If True, the queue is first
        restored from ``self._tgt_buffer`` at the index
        ``cur_step_local - _steps_local`` of ``self._sim``.
    """
    if not in_backward:
        self._tgt_buffer.append(self._tgt.copy())
    else:
        # use negative index because buffer length might not be full
        step_offset = self._sim.cur_step_local - self._sim._steps_local
        self._tgt = self._tgt_buffer[step_offset].copy()

    # Iterate over a snapshot: the setters invoked below append to self._tgt themselves.
    pending = self._tgt.copy()
    for entry in pending:
        key = entry["key"]
        assert key in self._tgt_keys, f"Invalid target key: {key} not in {self._tgt_keys}"

        # zero_velocity is always False on replay: per the original note, if it had been
        # requested, the [set_dofs_velocity] call it triggers would already be in [tgt].
        if key == "pos":
            self.set_pos(
                entry["pos"],
                envs_idx=entry["envs_idx"],
                relative=entry["relative"],
                zero_velocity=False,
                unsafe=entry["unsafe"],
            )
        elif key == "quat":
            self.set_quat(
                entry["quat"],
                envs_idx=entry["envs_idx"],
                relative=entry["relative"],
                zero_velocity=False,
                unsafe=entry["unsafe"],
            )
        elif key == "qpos":
            self.set_qpos(
                entry["qpos"],
                qs_idx_local=entry["qs_idx_local"],
                envs_idx=entry["envs_idx"],
                zero_velocity=False,
                unsafe=entry["unsafe"],
            )
        elif key == "dofs_velocity":
            self.set_dofs_velocity(
                entry["velocity"],
                dofs_idx_local=entry["dofs_idx_local"],
                envs_idx=entry["envs_idx"],
                unsafe=entry["unsafe"],
            )

    # Drop everything queued so far, including records re-appended by the setters above.
    self._tgt = []
|
|
||
def process_input_grad(self):
    """Propagate gradients to the targets recorded for the current step, last record first.

    The records are read from ``self._tgt_buffer`` at index
    ``cur_step_local - _steps_local`` of ``self._sim``. For each record whose value is
    not None and has ``requires_grad`` set, the value's ``_backward_from_ti`` is called
    with the matching ``set_*_grad`` method and the recorded arguments.

    Raises
    ------
    NotImplementedError
        If a recorded ``qpos`` target requires gradients.
    """
    step_offset = self._sim.cur_step_local - self._sim._steps_local
    recorded = self._tgt_buffer[step_offset].copy()

    for entry in reversed(recorded):
        key = entry["key"]
        assert key in self._tgt_keys, f"Invalid target key: {key} not in {self._tgt_keys}"

        if key == "pos":
            pos = entry["pos"]
            envs_idx = entry["envs_idx"]
            relative = entry["relative"]
            unsafe = entry["unsafe"]

            needs_grad = pos is not None and pos.requires_grad
            if needs_grad:
                pos._backward_from_ti(self.set_pos_grad, envs_idx, relative, unsafe)

        elif key == "quat":
            quat = entry["quat"]
            envs_idx = entry["envs_idx"]
            relative = entry["relative"]
            unsafe = entry["unsafe"]

            needs_grad = quat is not None and quat.requires_grad
            if needs_grad:
                quat._backward_from_ti(self.set_quat_grad, envs_idx, relative, unsafe)

        elif key == "qpos":
            qpos = entry["qpos"]
            # These fields are looked up although unused, so a malformed record
            # still fails with the same KeyError as before.
            qs_idx_local = entry["qs_idx_local"]
            envs_idx = entry["envs_idx"]
            unsafe = entry["unsafe"]

            needs_grad = qpos is not None and qpos.requires_grad
            if needs_grad:
                # TODO: Not implemented yet
                raise NotImplementedError("Backward pass for set_qpos_grad is not implemented yet.")

        elif key == "dofs_velocity":
            velocity = entry["velocity"]
            dofs_idx_local = entry["dofs_idx_local"]
            envs_idx = entry["envs_idx"]
            unsafe = entry["unsafe"]

            needs_grad = velocity is not None and velocity.requires_grad
            if needs_grad:
                velocity._backward_from_ti(self.set_dofs_velocity_grad, dofs_idx_local, envs_idx, unsafe)
|
||
|
|
||
def save_ckpt(self, ckpt_name):
    """Store a shallow copy of the target buffer under ``ckpt_name`` and empty the buffer.

    Parameters
    ----------
    ckpt_name : hashable
        Key in ``self._ckpt``. An entry is created if missing; other keys already
        stored in an existing entry are left untouched.
    """
    slot = self._ckpt.setdefault(ckpt_name, {})
    slot["_tgt_buffer"] = self._tgt_buffer.copy()
    # Clear in place so the same buffer object keeps being used afterwards.
    self._tgt_buffer.clear()
|
|
||
def load_ckpt(self, ckpt_name):
    """Replace the target buffer with a shallow copy of the one saved under ``ckpt_name``.

    Raises
    ------
    KeyError
        If ``ckpt_name`` is not in ``self._ckpt`` or has no ``"_tgt_buffer"`` entry.
    """
    saved_buffer = self._ckpt[ckpt_name]["_tgt_buffer"]
    # Copy so later changes to the live buffer do not alter the checkpoint.
    self._tgt_buffer = saved_buffer.copy()
|
|
||
def reset_grad(self):
    """Empty the target buffer in place."""
    self._tgt_buffer.clear()
|
|
||
@gs.assert_built
def get_state(self):
    """Return a ``RigidEntityState`` for the current global step, filled from the solver state.

    The state's ``_pos`` and ``_quat`` are replaced by the solver's ``links_pos`` and
    ``links_quat`` taken at ``self._base_links_idx`` with the second-to-last axis
    squeezed out. Shapes are asserted to match the values the state was created with.
    """
    state = RigidEntityState(self, self._sim.cur_step_global)

    solver_state = self._solver.get_state()
    base_idx = self._base_links_idx
    base_pos = solver_state.links_pos[:, base_idx].squeeze(-2)
    base_quat = solver_state.links_quat[:, base_idx].squeeze(-2)

    # Check both shapes before overwriting either field.
    assert state._pos.shape == base_pos.shape
    assert state._quat.shape == base_quat.shape
    state._pos, state._quat = base_pos, base_quat

    return state
|
|
||
| def get_joint(self, name=None, uid=None): | ||
| """ | ||
|
|
@@ -1964,6 +2106,18 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns | |
| envs_idx : None | array_like, optional | ||
| The indices of the environments. If None, all environments will be considered. Defaults to None. | ||
| """ | ||
| # Save in [tgt] for backward pass | ||
| self._tgt.append( | ||
| { | ||
| "key": "pos", | ||
| "pos": pos, | ||
| "envs_idx": envs_idx, | ||
| "relative": relative, | ||
| "zero_velocity": zero_velocity, | ||
| "unsafe": unsafe, | ||
| } | ||
| ) | ||
|
|
||
| if not unsafe: | ||
| _pos = torch.as_tensor(pos, dtype=gs.tc_float, device=gs.device).contiguous() | ||
| if _pos is not pos: | ||
|
|
@@ -1980,6 +2134,18 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns | |
| if zero_velocity: | ||
| self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) | ||
|
|
||
@gs.assert_built
def set_pos_grad(self, envs_idx, relative, unsafe, pos_grad):
    """Run the solver's backward step for ``set_pos`` and write the result into ``pos_grad``.

    A clone of ``pos_grad`` with an extra axis inserted at position -2 is passed to
    ``self._solver.set_base_links_pos_grad``; afterwards that clone, with the axis
    squeezed out, is assigned to ``pos_grad.data``.
    """
    grad_buffer = pos_grad.unsqueeze(-2).clone()
    self._solver.set_base_links_pos_grad(self._base_links_idx, envs_idx, relative, unsafe, grad_buffer)
    # NOTE(review): presumably the solver updates grad_buffer in place - confirm.
    pos_grad.data = grad_buffer.squeeze(-2)
|
|
||
| @gs.assert_built | ||
| def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False): | ||
| """ | ||
|
|
@@ -1998,6 +2164,17 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u | |
| envs_idx : None | array_like, optional | ||
| The indices of the environments. If None, all environments will be considered. Defaults to None. | ||
| """ | ||
| # Save in [tgt] for backward pass | ||
| self._tgt.append( | ||
| { | ||
| "key": "quat", | ||
| "quat": quat, | ||
| "envs_idx": envs_idx, | ||
| "relative": relative, | ||
| "zero_velocity": zero_velocity, | ||
| "unsafe": unsafe, | ||
| } | ||
| ) | ||
| if not unsafe: | ||
| _quat = torch.as_tensor(quat, dtype=gs.tc_float, device=gs.device).contiguous() | ||
| if _quat is not quat: | ||
|
|
@@ -2014,6 +2191,18 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u | |
| if zero_velocity: | ||
| self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) | ||
|
|
||
@gs.assert_built
def set_quat_grad(self, envs_idx, relative, unsafe, quat_grad):
    """Run the solver's backward step for ``set_quat`` and write the result into ``quat_grad``.

    A clone of ``quat_grad`` with an extra axis inserted at position -2 is passed to
    ``self._solver.set_base_links_quat_grad``; afterwards that clone, with the axis
    squeezed out, is assigned to ``quat_grad.data``.
    """
    grad_buffer = quat_grad.unsqueeze(-2).clone()
    self._solver.set_base_links_quat_grad(self._base_links_idx, envs_idx, relative, unsafe, grad_buffer)
    # NOTE(review): presumably the solver updates grad_buffer in place - confirm.
    quat_grad.data = grad_buffer.squeeze(-2)
|
|
||
| @gs.assert_built | ||
| def get_verts(self): | ||
| """ | ||
|
|
@@ -2106,6 +2295,18 @@ def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True | |
| zero_velocity : bool, optional | ||
| Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose. | ||
| """ | ||
| # Save in [tgt] for backward pass | ||
| self._tgt.append( | ||
| { | ||
| "key": "qpos", | ||
| "qpos": qpos, | ||
| "qs_idx_local": qs_idx_local, | ||
| "envs_idx": envs_idx, | ||
| "zero_velocity": zero_velocity, | ||
| "unsafe": unsafe, | ||
| } | ||
| ) | ||
|
|
||
| qs_idx = self._get_idx(qs_idx_local, self.n_qs, self._q_start, unsafe=True) | ||
| self._solver.set_qpos(qpos, qs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity) | ||
| if zero_velocity: | ||
|
|
@@ -2201,9 +2402,24 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, * | |
| envs_idx : None | array_like, optional | ||
| The indices of the environments. If None, all environments will be considered. Defaults to None. | ||
| """ | ||
| # Save in [tgt] for backward pass | ||
| self._tgt.append( | ||
| { | ||
| "key": "dofs_velocity", | ||
| "velocity": velocity, | ||
| "dofs_idx_local": dofs_idx_local, | ||
| "envs_idx": envs_idx, | ||
| "unsafe": unsafe, | ||
| } | ||
| ) | ||
| dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) | ||
| self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=False, unsafe=unsafe) | ||
|
|
||
@gs.assert_built
def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, unsafe, velocity_grad):
    """Forward the gradient of a ``set_dofs_velocity`` call to the solver.

    ``dofs_idx_local`` is converted with ``self._get_idx`` (using ``self.n_dofs`` and
    ``self._dof_start``) before being passed to ``self._solver.set_dofs_velocity_grad``.
    """
    solver_dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
    self._solver.set_dofs_velocity_grad(solver_dofs_idx, envs_idx, unsafe, velocity_grad)
|
|
||
| @gs.assert_built | ||
| def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False): | ||
| """ | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is the noslip feature supported when gradient computation is enabled? If not, you should raise an exception at init. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,10 +36,12 @@ def kernel_build_efc_AR_b( | |
| rigid_solver.func_solve_mass_batched( | ||
| constraint_state.Mgrad, | ||
| constraint_state.Mgrad, | ||
| constraint_state.Mgrad, # this will not be used anyway because is_backward is False | ||
|
||
| i_b, | ||
| entities_info=entities_info, | ||
| rigid_global_info=rigid_global_info, | ||
| static_rigid_sim_config=static_rigid_sim_config, | ||
| is_backward=False, | ||
| ) | ||
|
|
||
| # AR[r, c] = J[c, :] * tmp | ||
|
|
@@ -192,10 +194,12 @@ def kernel_dual_finish( | |
| rigid_solver.func_solve_mass_batched( | ||
| vec=constraint_state.qfrc_constraint, | ||
| out=constraint_state.qacc, | ||
| out_bw=constraint_state.qacc, # this will not be used anyway because is_backward is False | ||
| i_b=i_b, | ||
| entities_info=entities_info, | ||
| rigid_global_info=rigid_global_info, | ||
| static_rigid_sim_config=static_rigid_sim_config, | ||
| is_backward=False, | ||
| ) | ||
|
|
||
| for i_d in range(n_dofs): | ||
|
|
@@ -284,10 +288,12 @@ def compute_A_diag( | |
| rigid_solver.func_solve_mass_batched( | ||
| constraint_state.Mgrad, | ||
| constraint_state.Mgrad, | ||
| constraint_state.Mgrad, # this will not be used anyway because is_backward is False | ||
| i_b, | ||
| entities_info=entities_info, | ||
| rigid_global_info=rigid_global_info, | ||
| static_rigid_sim_config=static_rigid_sim_config, | ||
| is_backward=False, | ||
| ) | ||
|
|
||
| # Ai = Ji * tmp | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not leave a blank line at the top of undocumented methods.