Skip to content
205 changes: 200 additions & 5 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from genesis.utils import mjcf as mju
from genesis.utils import terrain as tu
from genesis.utils import urdf as uu
from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch
from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch, to_gs_tensor
from genesis.engine.states.entities import RigidEntityState

from ..base_entity import Entity
from .rigid_equality import RigidEquality
Expand Down Expand Up @@ -95,6 +96,22 @@ def __init__(

self._load_model()

# Initialize target variables and checkpoint
self._tgt_keys = ["pos", "quat", "qpos", "dofs_velocity"]
self._tgt = dict()
self._tgt_buffer = list()
self._ckpt = dict()

def update_tgt(self, key, value):
# Set [self._tgt] value while keeping the insertion order between keys. When a new key is inserted or an existing
# key is updated, the new element should be inserted at the end of the dict. This is because we need to keep
# the insertion order to correctly pass the gradients in the backward pass.
self._tgt.pop(key, None)
self._tgt[key] = value

def init_ckpt(self):
pass

def _load_model(self):
self._links = gs.List()
self._joints = gs.List()
Expand Down Expand Up @@ -1447,6 +1464,7 @@ def _kernel_forward_kinematics(
entities_info,
rigid_global_info,
static_rigid_sim_config,
is_backward=False,
)

ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL))
Expand Down Expand Up @@ -1476,6 +1494,7 @@ def _kernel_forward_kinematics(
entities_info,
rigid_global_info,
static_rigid_sim_config,
is_backward=False,
)

# ------------------------------------------------------------------------------------
Expand Down Expand Up @@ -1610,6 +1629,104 @@ def plan_path(
# ------------------------------------------------------------------------------------
# ---------------------------------- control & io ------------------------------------
# ------------------------------------------------------------------------------------
def process_input(self, in_backward=False):
if in_backward:
# use negative index because buffer length might not be full
index = self._sim.cur_step_local - self._sim._steps_local
self._tgt = self._tgt_buffer[index].copy()
else:
self._tgt_buffer.append(self._tgt.copy())

# Apply targets in the order of insertion
for key in self._tgt.keys():
data_kwargs = self._tgt[key]

# We do not need zero velocity here because if it was true, [set_dofs_velocity] from zero_velocity would
# be in [tgt]
if "zero_velocity" in data_kwargs:
data_kwargs["zero_velocity"] = False
# Do not update [tgt], as input information is finalized at this point
data_kwargs["update_tgt"] = False

match key:
case "pos":
self.set_pos(**data_kwargs)
case "quat":
self.set_quat(**data_kwargs)
case "qpos":
self.set_qpos(**data_kwargs)
case "dofs_velocity":
self.set_dofs_velocity(**data_kwargs)
case _:
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")

self._tgt = dict()

def process_input_grad(self):
index = self._sim.cur_step_local - self._sim._steps_local
for key in reversed(self._tgt_buffer[index].keys()):
data_kwargs = self._tgt_buffer[index][key]

match key:
# We need to unpack the data_kwargs because [_backward_from_ti] only supports positional arguments
case "pos":
pos = data_kwargs.pop("pos")
if pos.requires_grad:
pos._backward_from_ti(
self.set_pos_grad, data_kwargs["envs_idx"], data_kwargs["relative"], data_kwargs["unsafe"]
)

case "quat":
quat = data_kwargs.pop("quat")
if quat.requires_grad:
quat._backward_from_ti(
self.set_quat_grad, data_kwargs["envs_idx"], data_kwargs["relative"], data_kwargs["unsafe"]
)

case "qpos":
qpos = data_kwargs.pop("qpos")
if qpos.requires_grad:
raise NotImplementedError("Backward pass for set_qpos_grad is not implemented yet.")

case "dofs_velocity":
velocity = data_kwargs.pop("velocity")
# [velocity] could be None when we want to zero the velocity (see set_dofs_velocity of RigidSolver)
if velocity is not None and velocity.requires_grad:
velocity._backward_from_ti(
self.set_dofs_velocity_grad,
data_kwargs["dofs_idx_local"],
data_kwargs["envs_idx"],
data_kwargs["unsafe"],
)
case _:
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")

def save_ckpt(self, ckpt_name):
if ckpt_name not in self._ckpt:
self._ckpt[ckpt_name] = {}
self._ckpt[ckpt_name]["_tgt_buffer"] = self._tgt_buffer.copy()
self._tgt_buffer.clear()

def load_ckpt(self, ckpt_name):
self._tgt_buffer = self._ckpt[ckpt_name]["_tgt_buffer"].copy()

def reset_grad(self):
self._tgt_buffer.clear()

@gs.assert_built
def get_state(self):
state = RigidEntityState(self, self._sim.cur_step_global)

solver_state = self._solver.get_state()
pos = solver_state.links_pos[:, self.base_link_idx]
quat = solver_state.links_quat[:, self.base_link_idx]

assert state._pos.shape == pos.shape
assert state._quat.shape == quat.shape
state._pos = pos
state._quat = quat

return state

def get_joint(self, name=None, uid=None):
"""
Expand Down Expand Up @@ -1937,7 +2054,7 @@ def get_links_invweight(self, links_idx_local=None, envs_idx=None, *, unsafe=Fal
return self._solver.get_links_invweight(links_idx, envs_idx, unsafe=unsafe)

@gs.assert_built
def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False, update_tgt=True):
"""
Set position of the entity's base link.

Expand All @@ -1954,6 +2071,19 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
# Save in [tgt] for backward pass
if update_tgt:
self.update_tgt(
"pos",
{
"pos": pos,
"envs_idx": envs_idx,
"relative": relative,
"zero_velocity": zero_velocity,
"unsafe": unsafe,
},
)

if not unsafe:
_pos = torch.as_tensor(pos, dtype=gs.tc_float, device=gs.device).contiguous()
if _pos is not pos:
Expand All @@ -1971,7 +2101,19 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)

@gs.assert_built
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
def set_pos_grad(self, envs_idx, relative, unsafe, pos_grad):
tmp_pos_grad = pos_grad.unsqueeze(-2)
self._solver.set_base_links_pos_grad(
self._base_links_idx,
envs_idx,
relative,
unsafe,
tmp_pos_grad,
)
pos_grad.data = tmp_pos_grad.squeeze(-2)

@gs.assert_built
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False, update_tgt=True):
"""
Set quaternion of the entity's base link.

Expand All @@ -1988,6 +2130,18 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
# Save in [tgt] for backward pass
if update_tgt:
self.update_tgt(
"quat",
{
"quat": quat,
"envs_idx": envs_idx,
"relative": relative,
"zero_velocity": zero_velocity,
"unsafe": unsafe,
},
)
if not unsafe:
_quat = torch.as_tensor(quat, dtype=gs.tc_float, device=gs.device).contiguous()
if _quat is not quat:
Expand All @@ -2004,6 +2158,18 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
if zero_velocity:
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)

@gs.assert_built
def set_quat_grad(self, envs_idx, relative, unsafe, quat_grad):
tmp_quat_grad = quat_grad.unsqueeze(-2)
self._solver.set_base_links_quat_grad(
self._base_links_idx,
envs_idx,
relative,
unsafe,
tmp_quat_grad,
)
quat_grad.data = tmp_quat_grad.squeeze(-2)

@gs.assert_built
def get_verts(self):
"""
Expand Down Expand Up @@ -2081,7 +2247,7 @@ def _get_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False
return idx_global

@gs.assert_built
def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True, unsafe=False):
def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True, unsafe=False, update_tgt=True):
"""
Set the entity's qpos.

Expand All @@ -2096,6 +2262,19 @@ def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True
zero_velocity : bool, optional
Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose.
"""
# Save in [tgt] for backward pass
if update_tgt:
self.update_tgt(
"qpos",
{
"qpos": qpos,
"qs_idx_local": qs_idx_local,
"envs_idx": envs_idx,
"zero_velocity": zero_velocity,
"unsafe": unsafe,
},
)

qs_idx = self._get_idx(qs_idx_local, self.n_qs, self._q_start, unsafe=True)
self._solver.set_qpos(qpos, qs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity)
if zero_velocity:
Expand Down Expand Up @@ -2178,7 +2357,7 @@ def set_dofs_damping(self, damping, dofs_idx_local=None, envs_idx=None, *, unsaf
self._solver.set_dofs_damping(damping, dofs_idx, envs_idx, unsafe=unsafe)

@gs.assert_built
def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, unsafe=False, update_tgt=True):
"""
Set the entity's dofs' velocity.

Expand All @@ -2191,9 +2370,25 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
# Save in [tgt] for backward pass
if update_tgt:
self.update_tgt(
"dofs_velocity",
{
"velocity": velocity,
"dofs_idx_local": dofs_idx_local,
"envs_idx": envs_idx,
"unsafe": unsafe,
},
)
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=False, unsafe=unsafe)

@gs.assert_built
def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, unsafe, velocity_grad):
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, unsafe, velocity_grad)

@gs.assert_built
def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
"""
Expand Down
17 changes: 6 additions & 11 deletions genesis/engine/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,18 +272,13 @@ def f_global_to_s_global(self, f_global):
# ------------------------------------------------------------------------------------

def step(self, in_backward=False):
if self._rigid_only: # "Only Advance!" --Thomas Wade :P
for _ in range(self._substeps):
self.rigid_solver.substep()
self._cur_substep_global += 1
else:
self.process_input(in_backward=in_backward)
for _ in range(self._substeps):
self.substep(self.cur_substep_local)
self.process_input(in_backward=in_backward)
for _ in range(self._substeps):
self.substep(self.cur_substep_local)

self._cur_substep_global += 1
if self.cur_substep_local == 0 and not in_backward:
self.save_ckpt()
self._cur_substep_global += 1
if self.cur_substep_local == 0 and not in_backward:
self.save_ckpt()

if self.rigid_solver.is_active:
self.rigid_solver.clear_external_force()
Expand Down
6 changes: 6 additions & 0 deletions genesis/engine/solvers/rigid/constraint_noslip.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is noslip feature supported when enabling gradient computation? If not, you should raise an exception at init.

Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ def kernel_build_efc_AR_b(
rigid_solver.func_solve_mass_batched(
constraint_state.Mgrad,
constraint_state.Mgrad,
array_class.PLACEHOLDER,
i_b,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)

# AR[r, c] = J[c, :] * tmp
Expand Down Expand Up @@ -189,10 +191,12 @@ def kernel_dual_finish(
rigid_solver.func_solve_mass_batched(
vec=constraint_state.qfrc_constraint,
out=constraint_state.qacc,
out_bw=array_class.PLACEHOLDER,
i_b=i_b,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)

for i_d in range(n_dofs):
Expand Down Expand Up @@ -280,10 +284,12 @@ def compute_A_diag(
rigid_solver.func_solve_mass_batched(
constraint_state.Mgrad,
constraint_state.Mgrad,
array_class.PLACEHOLDER,
i_b,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)

# Ai = Ji * tmp
Expand Down
2 changes: 2 additions & 0 deletions genesis/engine/solvers/rigid/constraint_solver_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1984,10 +1984,12 @@ def func_update_gradient(
rigid_solver.func_solve_mass_batched(
constraint_state.grad,
constraint_state.Mgrad,
array_class.PLACEHOLDER,
i_b,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)

elif ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@ def _func_update_gradient(self, island, i_b):
i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity
i_e = self.contact_island.entity_id[i_e_, i_b]
self._solver.mass_mat_mask[i_e_, i_b] = True
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, i_b)
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, None, i_b)
for i_e in range(self._solver.n_entities):
self._solver.mass_mat_mask[i_e, i_b] = True
elif ti.static(self._solver_type == gs.constraint_solver.Newton):
Expand Down
Loading
Loading