Skip to content
218 changes: 217 additions & 1 deletion genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from genesis.utils import mjcf as mju
from genesis.utils import terrain as tu
from genesis.utils import urdf as uu
from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch
from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch, to_gs_tensor
from genesis.engine.states.entities import RigidEntityState

from ..base_entity import Entity
from .rigid_equality import RigidEquality
Expand Down Expand Up @@ -95,6 +96,23 @@ def __init__(

self._load_model()

self.init_tgt_vars()
self.init_ckpt()

def init_tgt_keys(self):

self._tgt_keys = ["pos", "quat", "qpos", "dofs_velocity"]

def init_tgt_vars(self):

# temp variable to store targets for next step
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not skip line on top of undocumented methods.

self._tgt = []
self._tgt_buffer = []
self.init_tgt_keys()

def init_ckpt(self):
self._ckpt = dict()

def _load_model(self):
self._links = gs.List()
self._joints = gs.List()
Expand Down Expand Up @@ -1460,6 +1478,7 @@ def _kernel_forward_kinematics(
entities_info,
rigid_global_info,
static_rigid_sim_config,
False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use keyword argument when passing unnamed variables.

)

ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
Expand All @@ -1486,6 +1505,7 @@ def _kernel_forward_kinematics(
entities_info,
rigid_global_info,
static_rigid_sim_config,
False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use keyword argument when passing unnamed variables.

)

# ------------------------------------------------------------------------------------
Expand Down Expand Up @@ -1620,6 +1640,128 @@ def plan_path(
# ------------------------------------------------------------------------------------
# ---------------------------------- control & io ------------------------------------
# ------------------------------------------------------------------------------------
def process_input(self, in_backward=False):
if in_backward:
# use negative index because buffer length might not be full
index = self._sim.cur_step_local - self._sim._steps_local
self._tgt = self._tgt_buffer[index].copy()
else:
self._tgt_buffer.append(self._tgt.copy())

# Apply targets sequentially
_tgt = self._tgt.copy()
for tgt in _tgt:
k = tgt["key"]
assert k in self._tgt_keys, f"Invalid target key: {k} not in {self._tgt_keys}"

# We do not need zero velocity here because if it was true, [set_dofs_velocity] from zero_velocity would
# be in [tgt]
zero_velocity = False
if k == "pos":
_pos = tgt["pos"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern is weird. It looks like self._tgt should just be a dictionary of dictionary, using tgt["key"] as top-level key.

Note that built-in python dict is preserving insertion order as per official python standard, just like sequence containers list and tuple.

_envs_idx = tgt["envs_idx"]
_relative = tgt["relative"]
_unsafe = tgt["unsafe"]

self.set_pos(_pos, envs_idx=_envs_idx, relative=_relative, zero_velocity=zero_velocity, unsafe=_unsafe)
elif k == "quat":
_quat = tgt["quat"]
_envs_idx = tgt["envs_idx"]
_relative = tgt["relative"]
_unsafe = tgt["unsafe"]

self.set_quat(
_quat, envs_idx=_envs_idx, relative=_relative, zero_velocity=zero_velocity, unsafe=_unsafe
)
elif k == "qpos":
_qpos = tgt["qpos"]
_qs_idx_local = tgt["qs_idx_local"]
_envs_idx = tgt["envs_idx"]
_unsafe = tgt["unsafe"]

self.set_qpos(
_qpos, qs_idx_local=_qs_idx_local, envs_idx=_envs_idx, zero_velocity=zero_velocity, unsafe=_unsafe
)
elif k == "dofs_velocity":
_velocity = tgt["velocity"]
_dofs_idx_local = tgt["dofs_idx_local"]
_envs_idx = tgt["envs_idx"]
_unsafe = tgt["unsafe"]

self.set_dofs_velocity(_velocity, dofs_idx_local=_dofs_idx_local, envs_idx=_envs_idx, unsafe=_unsafe)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be a perfect use case for match-case!

for key, tgt in self._tgt.items():
    match key:
        case "pos":
            [...]
        case "quat":
            [...]
        case _:
            gs.raise_exception(f"Invalid target key: {k} not in {self._tgt_keys}")


self._tgt = []

def process_input_grad(self):
index = self._sim.cur_step_local - self._sim._steps_local
_tgt = self._tgt_buffer[index].copy()

for tgt in reversed(_tgt):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using _tgt vs tgt to distinguish top-level container from element is far from ideal. What about something like tgt vs data_kwargs or something like this?

k = tgt["key"]
assert k in self._tgt_keys, f"Invalid target key: {k} not in {self._tgt_keys}"
if k == "pos":
_pos = tgt["pos"]
_envs_idx = tgt["envs_idx"]
_relative = tgt["relative"]
_unsafe = tgt["unsafe"]

if _pos is not None and _pos.requires_grad:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you guarding for _pos is not None? It is you that added this None value in tgt in the first place? If so, it would make more sense to simple skip adding keys that are irrelevant.

_pos._backward_from_ti(self.set_pos_grad, _envs_idx, _relative, _unsafe)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you unpacking all these arguments to pass them back? It would be more clear and future proof to just do:

_pos._backward_from_ti(self.set_pos_grad, **tgt)


elif k == "quat":
_quat = tgt["quat"]
_envs_idx = tgt["envs_idx"]
_relative = tgt["relative"]
_unsafe = tgt["unsafe"]

if _quat is not None and _quat.requires_grad:
_quat._backward_from_ti(self.set_quat_grad, _envs_idx, _relative, _unsafe)

elif k == "qpos":
_qpos = tgt["qpos"]
_qs_idx_local = tgt["qs_idx_local"]
_envs_idx = tgt["envs_idx"]
_unsafe = tgt["unsafe"]

if _qpos is not None and _qpos.requires_grad:
# TODO: Not implemented yet
raise NotImplementedError("Backward pass for set_qpos_grad is not implemented yet.")

elif k == "dofs_velocity":
_velocity = tgt["velocity"]
_dofs_idx_local = tgt["dofs_idx_local"]
_envs_idx = tgt["envs_idx"]
_unsafe = tgt["unsafe"]

if _velocity is not None and _velocity.requires_grad:
_velocity._backward_from_ti(self.set_dofs_velocity_grad, _dofs_idx_local, _envs_idx, _unsafe)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you prefixing with _ ? What is the meaning of this? This usually means private variable, which obviously is not applicable here.


def save_ckpt(self, ckpt_name):
if ckpt_name not in self._ckpt:
self._ckpt[ckpt_name] = {}
self._ckpt[ckpt_name]["_tgt_buffer"] = self._tgt_buffer.copy()
self._tgt_buffer.clear()

def load_ckpt(self, ckpt_name):
self._tgt_buffer = self._ckpt[ckpt_name]["_tgt_buffer"].copy()

def reset_grad(self):
self._tgt_buffer.clear()

@gs.assert_built
def get_state(self):
state = RigidEntityState(self, self._sim.cur_step_global)

solver_state = self._solver.get_state()
pos = solver_state.links_pos[:, self._base_links_idx].squeeze(-2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only use negative squeeze index if it is necessary, i.e. to support optional starting dimensional. In this case, 1 ?

Beyond that, not that you could extract unsqueeze pos directly if you want:

pos = solver_state.links_pos[:, self._base_links_idx : self._base_links_idx + 1]

Note that pos = solver_state.links_pos[:, [self._base_links_idx]] would also work but it is MUCH slower because, first, the list of indices must be converted to a torch tensor on device, then it will rely on fancy indexing and return a copy instead of a view.

quat = solver_state.links_quat[:, self._base_links_idx].squeeze(-2)

assert state._pos.shape == pos.shape
assert state._quat.shape == quat.shape
state._pos = pos
state._quat = quat

return state

def get_joint(self, name=None, uid=None):
"""
Expand Down Expand Up @@ -1964,6 +2106,18 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
# Save in [tgt] for backward pass
self._tgt.append(
{
"key": "pos",
"pos": pos,
"envs_idx": envs_idx,
"relative": relative,
"zero_velocity": zero_velocity,
"unsafe": unsafe,
}
)

if not unsafe:
_pos = torch.as_tensor(pos, dtype=gs.tc_float, device=gs.device).contiguous()
if _pos is not pos:
Expand All @@ -1980,6 +2134,18 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
if zero_velocity:
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)

@gs.assert_built
def set_pos_grad(self, envs_idx, relative, unsafe, pos_grad):
tmp_pos_grad = pos_grad.unsqueeze(-2).clone()
self._solver.set_base_links_pos_grad(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why cloning is necessary in this context? If it is not, just remove it. But you need to know whatever it is guaranteed to be continuous or discontinuous by design.

self._base_links_idx,
envs_idx,
relative,
unsafe,
tmp_pos_grad,
)
pos_grad.data = tmp_pos_grad.squeeze(-2)

@gs.assert_built
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
"""
Expand All @@ -1998,6 +2164,17 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
# Save in [tgt] for backward pass
self._tgt.append(
{
"key": "quat",
"quat": quat,
"envs_idx": envs_idx,
"relative": relative,
"zero_velocity": zero_velocity,
"unsafe": unsafe,
}
)
if not unsafe:
_quat = torch.as_tensor(quat, dtype=gs.tc_float, device=gs.device).contiguous()
if _quat is not quat:
Expand All @@ -2014,6 +2191,18 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
if zero_velocity:
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)

@gs.assert_built
def set_quat_grad(self, envs_idx, relative, unsafe, quat_grad):
tmp_quat_grad = quat_grad.unsqueeze(-2).clone()
self._solver.set_base_links_quat_grad(
self._base_links_idx,
envs_idx,
relative,
unsafe,
tmp_quat_grad,
)
quat_grad.data = tmp_quat_grad.squeeze(-2)

@gs.assert_built
def get_verts(self):
"""
Expand Down Expand Up @@ -2106,6 +2295,18 @@ def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True
zero_velocity : bool, optional
Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose.
"""
# Save in [tgt] for backward pass
self._tgt.append(
{
"key": "qpos",
"qpos": qpos,
"qs_idx_local": qs_idx_local,
"envs_idx": envs_idx,
"zero_velocity": zero_velocity,
"unsafe": unsafe,
}
)

qs_idx = self._get_idx(qs_idx_local, self.n_qs, self._q_start, unsafe=True)
self._solver.set_qpos(qpos, qs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity)
if zero_velocity:
Expand Down Expand Up @@ -2201,9 +2402,24 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
# Save in [tgt] for backward pass
self._tgt.append(
{
"key": "dofs_velocity",
"velocity": velocity,
"dofs_idx_local": dofs_idx_local,
"envs_idx": envs_idx,
"unsafe": unsafe,
}
)
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=False, unsafe=unsafe)

@gs.assert_built
def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, unsafe, velocity_grad):
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, unsafe, velocity_grad)

@gs.assert_built
def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
"""
Expand Down
17 changes: 6 additions & 11 deletions genesis/engine/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,18 +268,13 @@ def f_global_to_s_global(self, f_global):
# ------------------------------------------------------------------------------------

def step(self, in_backward=False):
if self._rigid_only: # "Only Advance!" --Thomas Wade :P
for _ in range(self._substeps):
self.rigid_solver.substep()
self._cur_substep_global += 1
else:
self.process_input(in_backward=in_backward)
for _ in range(self._substeps):
self.substep(self.cur_substep_local)
self.process_input(in_backward=in_backward)
for _ in range(self._substeps):
self.substep(self.cur_substep_local)

self._cur_substep_global += 1
if self.cur_substep_local == 0 and not in_backward:
self.save_ckpt()
self._cur_substep_global += 1
if self.cur_substep_local == 0 and not in_backward:
self.save_ckpt()

if self.rigid_solver.is_active:
self.rigid_solver.clear_external_force()
Expand Down
6 changes: 6 additions & 0 deletions genesis/engine/solvers/rigid/constraint_noslip.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is noslip feature supported when enabling gradient computation? If not, you should raise an exception at init.

Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ def kernel_build_efc_AR_b(
rigid_solver.func_solve_mass_batched(
constraint_state.Mgrad,
constraint_state.Mgrad,
constraint_state.Mgrad, # this will not be used anyway because is_backward is False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get why specifying a different out variable for the value and its gradient is an option in the first place. It seems that it is never the case in your PR. If so, I would suggest removing this option that is adding complexity (and therefore increasing maintenance burden because in theory it should be unit-tested) without any benefit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After checking, it seems that another variable is specified when is_backward=True, so here if I understand correctly you are setting it to anything because it will not be used in practice. This does work in practice but I don't like it much... What about defining some extra free variable PLACEHOLDER in array_class (0D taichi tensor of type array_class.V) that we could use everywhere an argument is not used? This would clarify the intend and avoid any mistake because you cannot do much with such tensor.

i_b,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)

# AR[r, c] = J[c, :] * tmp
Expand Down Expand Up @@ -192,10 +194,12 @@ def kernel_dual_finish(
rigid_solver.func_solve_mass_batched(
vec=constraint_state.qfrc_constraint,
out=constraint_state.qacc,
out_bw=constraint_state.qacc, # this will not be used anyway because is_backward is False
i_b=i_b,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)

for i_d in range(n_dofs):
Expand Down Expand Up @@ -284,10 +288,12 @@ def compute_A_diag(
rigid_solver.func_solve_mass_batched(
constraint_state.Mgrad,
constraint_state.Mgrad,
constraint_state.Mgrad, # this will not be used anyway because is_backward is False
i_b,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)

# Ai = Ji * tmp
Expand Down
2 changes: 2 additions & 0 deletions genesis/engine/solvers/rigid/constraint_solver_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1987,10 +1987,12 @@ def func_update_gradient(
rigid_solver.func_solve_mass_batched(
constraint_state.grad,
constraint_state.Mgrad,
constraint_state.Mgrad, # this will not be used anyway because is_backward is False
i_b,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
is_backward=False,
)

elif ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ def _func_update_gradient(self, island, i_b):
i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity
i_e = self.contact_island.entity_id[i_e_, i_b]
self._solver._mass_mat_mask[i_e_, i_b] = 1
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, i_b)
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, None, i_b)
for i_e in range(self._solver.n_entities):
self._solver._mass_mat_mask[i_e, i_b] = 1
elif ti.static(self._solver_type == gs.constraint_solver.Newton):
Expand Down
Loading
Loading