diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 27569abd7..b95e77589 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -17,7 +17,8 @@ from genesis.utils import mjcf as mju from genesis.utils import terrain as tu from genesis.utils import urdf as uu -from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch +from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch, to_gs_tensor +from genesis.engine.states.entities import RigidEntityState from ..base_entity import Entity from .rigid_equality import RigidEquality @@ -95,6 +96,22 @@ def __init__( self._load_model() + # Initialize target variables and checkpoint + self._tgt_keys = ["pos", "quat", "qpos", "dofs_velocity"] + self._tgt = dict() + self._tgt_buffer = list() + self._ckpt = dict() + + def update_tgt(self, key, value): + # Set [self._tgt] value while keeping the insertion order between keys. When a new key is inserted or an existing + # key is updated, the new element should be inserted at the end of the dict. This is because we need to keep + # the insertion order to correctly pass the gradients in the backward pass. + self._tgt.pop(key, None) + self._tgt[key] = value + + def init_ckpt(self): + pass + def _load_model(self): self._links = gs.List() self._joints = gs.List() @@ -1447,6 +1464,7 @@ def _kernel_forward_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, + is_backward=False, ) ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) @@ -1476,6 +1494,7 @@ def _kernel_forward_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, + is_backward=False, ) # ------------------------------------------------------------------------------------ @@ -1610,6 +1629,104 @@ def plan_path( # ------------------------------------------------------------------------------------ # ---------------------------------- control & io ------------------------------------ # ------------------------------------------------------------------------------------ + def process_input(self, in_backward=False): + if in_backward: + # use negative index because buffer length might not be full + index = self._sim.cur_step_local - self._sim._steps_local + self._tgt = self._tgt_buffer[index].copy() + else: + self._tgt_buffer.append(self._tgt.copy()) + + # Apply targets in the order of insertion + for key in self._tgt.keys(): + data_kwargs = self._tgt[key] + + # We do not need zero velocity here because if it was true, [set_dofs_velocity] from zero_velocity would + # be in [tgt] + if "zero_velocity" in data_kwargs: + data_kwargs["zero_velocity"] = False + # Do not update [tgt], as input information is finalized at this point + data_kwargs["update_tgt"] = False + + match key: + case "pos": + self.set_pos(**data_kwargs) + case "quat": + self.set_quat(**data_kwargs) + case "qpos": + self.set_qpos(**data_kwargs) + case "dofs_velocity": + self.set_dofs_velocity(**data_kwargs) + case _: + gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}") + + self._tgt = dict() + + def process_input_grad(self): + index = self._sim.cur_step_local - self._sim._steps_local + for key in reversed(self._tgt_buffer[index].keys()): + data_kwargs = self._tgt_buffer[index][key] + + match key: + # We need to unpack the data_kwargs because [_backward_from_ti] only supports positional arguments + case "pos": + pos = data_kwargs.pop("pos") + if pos.requires_grad: + pos._backward_from_ti( + self.set_pos_grad, data_kwargs["envs_idx"], data_kwargs["relative"], data_kwargs["unsafe"] + ) + + case "quat": + quat = data_kwargs.pop("quat") + if quat.requires_grad: + quat._backward_from_ti( + self.set_quat_grad, data_kwargs["envs_idx"], data_kwargs["relative"], data_kwargs["unsafe"] + ) + + case "qpos": + qpos = data_kwargs.pop("qpos") + if qpos.requires_grad: + raise NotImplementedError("Backward pass for set_qpos_grad is not implemented yet.") + + case "dofs_velocity": + velocity = data_kwargs.pop("velocity") + # [velocity] could be None when we want to zero the velocity (see set_dofs_velocity of RigidSolver) + if velocity is not None and velocity.requires_grad: + velocity._backward_from_ti( + self.set_dofs_velocity_grad, + data_kwargs["dofs_idx_local"], + data_kwargs["envs_idx"], + data_kwargs["unsafe"], + ) + case _: + gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}") + + def save_ckpt(self, ckpt_name): + if ckpt_name not in self._ckpt: + self._ckpt[ckpt_name] = {} + self._ckpt[ckpt_name]["_tgt_buffer"] = self._tgt_buffer.copy() + self._tgt_buffer.clear() + + def load_ckpt(self, ckpt_name): + self._tgt_buffer = self._ckpt[ckpt_name]["_tgt_buffer"].copy() + + def reset_grad(self): + self._tgt_buffer.clear() + + @gs.assert_built + def get_state(self): + state = RigidEntityState(self, self._sim.cur_step_global) + + solver_state = self._solver.get_state() + pos = solver_state.links_pos[:, self.base_link_idx] + quat = solver_state.links_quat[:, self.base_link_idx] + + assert state._pos.shape == pos.shape + assert state._quat.shape == quat.shape + state._pos = pos + state._quat = quat + + return state def get_joint(self, name=None, uid=None): """ @@ -1937,7 +2054,7 @@ def get_links_invweight(self, links_idx_local=None, envs_idx=None, *, unsafe=Fal return self._solver.get_links_invweight(links_idx, envs_idx, unsafe=unsafe) @gs.assert_built - def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False): + def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False, update_tgt=True): """ Set position of the entity's base link. @@ -1954,6 +2071,19 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns envs_idx : None | array_like, optional The indices of the environments. If None, all environments will be considered. Defaults to None. """ + # Save in [tgt] for backward pass + if update_tgt: + self.update_tgt( + "pos", + { + "pos": pos, + "envs_idx": envs_idx, + "relative": relative, + "zero_velocity": zero_velocity, + "unsafe": unsafe, + }, + ) + if not unsafe: _pos = torch.as_tensor(pos, dtype=gs.tc_float, device=gs.device).contiguous() if _pos is not pos: @@ -1971,7 +2101,19 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) @gs.assert_built - def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False): + def set_pos_grad(self, envs_idx, relative, unsafe, pos_grad): + tmp_pos_grad = pos_grad.unsqueeze(-2) + self._solver.set_base_links_pos_grad( + self._base_links_idx, + envs_idx, + relative, + unsafe, + tmp_pos_grad, + ) + pos_grad.data = tmp_pos_grad.squeeze(-2) + + @gs.assert_built + def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False, update_tgt=True): """ Set quaternion of the entity's base link. @@ -1988,6 +2130,18 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u envs_idx : None | array_like, optional The indices of the environments. If None, all environments will be considered. Defaults to None. """ + # Save in [tgt] for backward pass + if update_tgt: + self.update_tgt( + "quat", + { + "quat": quat, + "envs_idx": envs_idx, + "relative": relative, + "zero_velocity": zero_velocity, + "unsafe": unsafe, + }, + ) if not unsafe: _quat = torch.as_tensor(quat, dtype=gs.tc_float, device=gs.device).contiguous() if _quat is not quat: @@ -2004,6 +2158,18 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u if zero_velocity: self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) + @gs.assert_built + def set_quat_grad(self, envs_idx, relative, unsafe, quat_grad): + tmp_quat_grad = quat_grad.unsqueeze(-2) + self._solver.set_base_links_quat_grad( + self._base_links_idx, + envs_idx, + relative, + unsafe, + tmp_quat_grad, + ) + quat_grad.data = tmp_quat_grad.squeeze(-2) + @gs.assert_built def get_verts(self): """ @@ -2081,7 +2247,7 @@ def _get_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False return idx_global @gs.assert_built - def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True, unsafe=False): + def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True, unsafe=False, update_tgt=True): """ Set the entity's qpos. @@ -2096,6 +2262,19 @@ def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True zero_velocity : bool, optional Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose. """ + # Save in [tgt] for backward pass + if update_tgt: + self.update_tgt( + "qpos", + { + "qpos": qpos, + "qs_idx_local": qs_idx_local, + "envs_idx": envs_idx, + "zero_velocity": zero_velocity, + "unsafe": unsafe, + }, + ) + qs_idx = self._get_idx(qs_idx_local, self.n_qs, self._q_start, unsafe=True) self._solver.set_qpos(qpos, qs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity) if zero_velocity: @@ -2178,7 +2357,7 @@ def set_dofs_damping(self, damping, dofs_idx_local=None, envs_idx=None, *, unsaf self._solver.set_dofs_damping(damping, dofs_idx, envs_idx, unsafe=unsafe) @gs.assert_built - def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, unsafe=False): + def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, unsafe=False, update_tgt=True): """ Set the entity's dofs' velocity. @@ -2191,9 +2370,25 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, * envs_idx : None | array_like, optional The indices of the environments. If None, all environments will be considered. Defaults to None. """ + # Save in [tgt] for backward pass + if update_tgt: + self.update_tgt( + "dofs_velocity", + { + "velocity": velocity, + "dofs_idx_local": dofs_idx_local, + "envs_idx": envs_idx, + "unsafe": unsafe, + }, + ) dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=False, unsafe=unsafe) + @gs.assert_built + def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, unsafe, velocity_grad): + dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) + self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, unsafe, velocity_grad) + @gs.assert_built def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False): """ diff --git a/genesis/engine/simulator.py b/genesis/engine/simulator.py index bc8328db2..9b6755b54 100644 --- a/genesis/engine/simulator.py +++ b/genesis/engine/simulator.py @@ -272,18 +272,13 @@ def f_global_to_s_global(self, f_global): # ------------------------------------------------------------------------------------ def step(self, in_backward=False): - if self._rigid_only: # "Only Advance!" --Thomas Wade :P - for _ in range(self._substeps): - self.rigid_solver.substep() - self._cur_substep_global += 1 - else: - self.process_input(in_backward=in_backward) - for _ in range(self._substeps): - self.substep(self.cur_substep_local) + self.process_input(in_backward=in_backward) + for _ in range(self._substeps): + self.substep(self.cur_substep_local) - self._cur_substep_global += 1 - if self.cur_substep_local == 0 and not in_backward: - self.save_ckpt() + self._cur_substep_global += 1 + if self.cur_substep_local == 0 and not in_backward: + self.save_ckpt() if self.rigid_solver.is_active: self.rigid_solver.clear_external_force() diff --git a/genesis/engine/solvers/rigid/constraint_noslip.py b/genesis/engine/solvers/rigid/constraint_noslip.py index e9114f593..57b5f8f44 100644 --- a/genesis/engine/solvers/rigid/constraint_noslip.py +++ b/genesis/engine/solvers/rigid/constraint_noslip.py @@ -35,10 +35,12 @@ def kernel_build_efc_AR_b( rigid_solver.func_solve_mass_batched( constraint_state.Mgrad, constraint_state.Mgrad, + array_class.PLACEHOLDER, i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=False, ) # AR[r, c] = J[c, :] * tmp @@ -189,10 +191,12 @@ def kernel_dual_finish( rigid_solver.func_solve_mass_batched( vec=constraint_state.qfrc_constraint, out=constraint_state.qacc, + out_bw=array_class.PLACEHOLDER, i_b=i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=False, ) for i_d in range(n_dofs): @@ -280,10 +284,12 @@ def compute_A_diag( rigid_solver.func_solve_mass_batched( constraint_state.Mgrad, constraint_state.Mgrad, + array_class.PLACEHOLDER, i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=False, ) # Ai = Ji * tmp diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index 5ae367bbc..71bf8ee22 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -1984,10 +1984,12 @@ def func_update_gradient( rigid_solver.func_solve_mass_batched( constraint_state.grad, constraint_state.Mgrad, + array_class.PLACEHOLDER, i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=False, ) elif ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py b/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py index 000a1dc20..df7afe348 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py @@ -990,7 +990,7 @@ def _func_update_gradient(self, island, i_b): i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity i_e = self.contact_island.entity_id[i_e_, i_b] self._solver.mass_mat_mask[i_e_, i_b] = True - self._solver._func_solve_mass_batched(self.grad, self.Mgrad, i_b) + self._solver._func_solve_mass_batched(self.grad, self.Mgrad, None, i_b) for i_e in range(self._solver.n_entities): self._solver.mass_mat_mask[i_e, i_b] = True elif ti.static(self._solver_type == gs.constraint_solver.Newton): diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index c650902d4..786e64861 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -11,7 +11,7 @@ import genesis.utils.geom as gu from genesis.engine.entities import AvatarEntity, DroneEntity, RigidEntity from genesis.engine.entities.base_entity import Entity -from genesis.engine.states.solvers import RigidSolverState +from genesis.engine.states import QueriedStates, RigidSolverState from genesis.options.solvers import RigidOptions from genesis.utils import linalg as lu from genesis.utils.misc import ( @@ -36,7 +36,6 @@ from genesis.engine.scene import Scene from genesis.engine.simulator import Simulator - # minimum constraint impedance IMP_MIN = 0.0001 # maximum constraint impedance @@ -135,6 +134,13 @@ def __init__(self, scene: "Scene", sim: "Simulator", options: RigidOptions) -> N self.qpos: ti.Template | ti.types.NDArray | None = None + self._queried_states = QueriedStates() + + self._ckpt = dict() + + def init_ckpt(self): + pass + def add_entity(self, idx, material, morph, surface, visualize_contact) -> Entity: if isinstance(material, gs.materials.Avatar): EntityClass = AvatarEntity @@ -202,6 +208,14 @@ def build(self): self._n_entities = self.n_entities self._n_equalities = self.n_equalities + self._max_n_links_per_entity = self.max_n_links_per_entity + self._max_n_joints_per_link = self.max_n_joints_per_link + self._max_n_dofs_per_joint = self.max_n_dofs_per_joint + self._max_n_qs_per_link = self.max_n_qs_per_link + self._max_n_dofs_per_entity = self.max_n_dofs_per_entity + self._max_n_dofs_per_link = self.max_n_dofs_per_link + self._max_n_geoms_per_entity = self.max_n_geoms_per_entity + self._geoms = self.geoms self._vgeoms = self.vgeoms self._links = self.links @@ -250,10 +264,60 @@ def build(self): enable_joint_limit=getattr(self, "_enable_joint_limit", False), box_box_detection=getattr(self, "_box_box_detection", True), sparse_solve=getattr(self._options, "sparse_solve", False), - integrator=getattr(self, "_integrator", gs.integrator.implicitfast), + integrator=getattr(self, "_integrator", gs.integrator.approximate_implicitfast), solver_type=getattr(self._options, "constraint_solver", gs.constraint_solver.CG), ) + if self._static_rigid_sim_config.requires_grad: + if self._static_rigid_sim_config.use_hibernation: + gs.raise_exception("Hibernation is not supported yet when requires_grad is True") + if self._static_rigid_sim_config.integrator != gs.integrator.approximate_implicitfast: + gs.raise_exception( + "Only approximate_implicitfast integrator is supported yet when requires_grad is True." + ) + from genesis.engine.couplers import SAPCoupler, IPCCoupler + + if isinstance(self.sim.coupler, SAPCoupler): + gs.raise_exception("SAPCoupler is not supported yet when requires_grad is True.") + + if isinstance(self.sim.coupler, IPCCoupler): + gs.raise_exception("IPCCoupler is not supported yet when requires_grad is True.") + + if getattr(self._options, "noslip_iterations", 0) > 0: + gs.raise_exception("Noslip is not supported yet when requires_grad is True.") + + # Add terms for static inner loops, use 0 if not requires_grad to avoid re-compilation + self._static_rigid_sim_config.max_n_links_per_entity = ( + getattr(self, "_max_n_links_per_entity", 0) if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_joints_per_link = ( + getattr(self, "_max_n_joints_per_link", 0) if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_dofs_per_joint = ( + getattr(self, "_max_n_dofs_per_joint", 0) if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_qs_per_link = ( + getattr(self, "_max_n_qs_per_link", 0) if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_dofs_per_entity = ( + getattr(self, "_max_n_dofs_per_entity", 0) if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_dofs_per_link = ( + getattr(self, "_max_n_dofs_per_link", 0) if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_geoms_per_entity = ( + getattr(self, "_max_n_geoms_per_entity", 0) if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_awake_links = ( + getattr(self, "_n_links", 0) if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_awake_entities = ( + getattr(self, "_n_entities", 0) if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_awake_dofs = ( + getattr(self, "_n_dofs", 0) if self._static_rigid_sim_config.requires_grad else 0 + ) + # when the migration is finished, we will remove the about two lines self._func_vel_at_point = func_vel_at_point self._func_apply_coupling_force = func_apply_coupling_force @@ -266,6 +330,7 @@ def build(self): self._errno = self.data_manager.errno self._rigid_global_info = self.data_manager.rigid_global_info + self._rigid_adjoint_cache = self.data_manager.rigid_adjoint_cache if self._use_hibernation: self.n_awake_dofs = self._rigid_global_info.n_awake_dofs self.awake_dofs = self._rigid_global_info.awake_dofs @@ -273,6 +338,11 @@ def build(self): self.awake_links = self._rigid_global_info.awake_links self.n_awake_entities = self._rigid_global_info.n_awake_entities self.awake_entities = self._rigid_global_info.awake_entities + if self._requires_grad: + self.dofs_state_adjoint_cache = self.data_manager.dofs_state_adjoint_cache + self.links_state_adjoint_cache = self.data_manager.links_state_adjoint_cache + self.joints_state_adjoint_cache = self.data_manager.joints_state_adjoint_cache + self.geoms_state_adjoint_cache = self.data_manager.geoms_state_adjoint_cache self._init_mass_mat() self._init_dof_fields() @@ -291,7 +361,7 @@ def build(self): self._init_constraint_solver() self._init_invweight_and_meaninertia(force_update=False) - self._func_update_geoms(self._scene._envs_idx, force_update_fixed_geoms=True) + self._func_update_geoms(self._scene._envs_idx, force_update_fixed_geoms=True, is_backward=False) def _init_invweight_and_meaninertia(self, envs_idx=None, *, force_update=True, unsafe=False): # Early return if no DoFs. This is essential to avoid segfault on CUDA. @@ -323,6 +393,7 @@ def _init_invweight_and_meaninertia(self, envs_idx=None, *, force_update=True, u rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, decompose=True, + is_backward=False, ) # Define some proxies for convenience @@ -825,12 +896,20 @@ def _get_links_data( return extract_slice(tensor, *batch_shape, keepdim, unsafe=unsafe) - def substep(self): + def substep(self, f): # from genesis.utils.tools import create_timer from genesis.engine.couplers import SAPCoupler self._links_state_cache.clear() + if f == 0: + kernel_save_adjoint_cache( + f=f, + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + rigid_adjoint_cache=self._rigid_adjoint_cache, + ) + kernel_step_1( links_state=self.links_state, links_info=self.links_info, @@ -845,6 +924,7 @@ def substep(self): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=False, ) if isinstance(self.sim.coupler, SAPCoupler): @@ -852,6 +932,7 @@ def substep(self): dofs_state=self.dofs_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) else: self._func_constraint_force() @@ -870,8 +951,38 @@ def substep(self): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=False, + ) + + kernel_copy_next_to_curr( + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, ) + kernel_save_adjoint_cache( + f=f + 1, + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + rigid_adjoint_cache=self._rigid_adjoint_cache, + ) + + if not self._static_rigid_sim_config.enable_mujoco_compatibility: + kernel_update_cartesian_space( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_info=self.geoms_info, + geoms_state=self.geoms_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=False, + ) + def check_errno(self): match kernel_get_errno(self._errno): case 1: @@ -932,6 +1043,7 @@ def _func_forward_dynamics(self): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=False, ) def _func_update_acc(self): @@ -942,6 +1054,7 @@ def _func_update_acc(self): entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) def _func_forward_kinematics_entity(self, i_e, envs_idx): @@ -957,6 +1070,7 @@ def _func_forward_kinematics_entity(self, i_e, envs_idx): entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) def _func_integrate_dq_entity(self, dq, i_e, i_b, respect_joint_limit): @@ -973,7 +1087,7 @@ def _func_integrate_dq_entity(self, dq, i_e, i_b, respect_joint_limit): static_rigid_sim_config=self._static_rigid_sim_config, ) - def _func_update_geoms(self, envs_idx, *, force_update_fixed_geoms=False): + def _func_update_geoms(self, envs_idx, *, force_update_fixed_geoms=False, is_backward=False): kernel_update_geoms( envs_idx, entities_info=self.entities_info, @@ -983,6 +1097,7 @@ def _func_update_geoms(self, envs_idx, *, force_update_fixed_geoms=False): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, force_update_fixed_geoms=force_update_fixed_geoms, + is_backward=is_backward, ) def _process_dim(self, tensor, envs_idx=None): @@ -1121,10 +1236,155 @@ def substep_pre_coupling(self, f): return # Run Genesis rigid simulation step - self.substep() + self.substep(f) def substep_pre_coupling_grad(self, f): - pass + # Load the current state from adjoint cache + curr_qpos = self._rigid_adjoint_cache.qpos.to_numpy()[f] + curr_dofs_vel = self._rigid_adjoint_cache.dofs_vel.to_numpy()[f] + curr_dofs_acc = self._rigid_adjoint_cache.dofs_acc.to_numpy()[f] + self._rigid_global_info.qpos.from_numpy(curr_qpos) + self.dofs_state.vel.from_numpy(curr_dofs_vel) + self.dofs_state.acc.from_numpy(curr_dofs_acc) + + # =================== Forward substep ====================== + if not self._enable_mujoco_compatibility: + kernel_update_cartesian_space( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=False, + ) + # Save results of [update_cartesian_space] to adjoint cache + kernel_copy_cartesian_space( + src_dofs_state=self.dofs_state, + src_links_state=self.links_state, + src_joints_state=self.joints_state, + src_geoms_state=self.geoms_state, + dst_dofs_state=self.dofs_state_adjoint_cache, + dst_links_state=self.links_state_adjoint_cache, + dst_joints_state=self.joints_state_adjoint_cache, + dst_geoms_state=self.geoms_state_adjoint_cache, + ) + + self.substep(f) + + # =================== Backward substep ====================== + if not self._enable_mujoco_compatibility: + kernel_update_cartesian_space.grad( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=True, + ) + + qpos_grad = self._rigid_global_info.qpos.grad.to_numpy() + dofs_vel_grad = self.dofs_state.vel.grad.to_numpy() + if np.isnan(qpos_grad).sum() > 0 or np.isnan(dofs_vel_grad).sum() > 0: + gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}") + + kernel_copy_next_to_curr.grad( + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + ) + + # Load the current state from adjoint cache, as it was overwritten by [kernel_copy_next_to_curr] + self._rigid_global_info.qpos.from_numpy(curr_qpos) + self.dofs_state.vel.from_numpy(curr_dofs_vel) + + if not self._enable_mujoco_compatibility: + # Load the previous outputs of [kernel_update_cartesian_space], as it was overwritten if we disabled mujoco + # compatibility + kernel_copy_cartesian_space( + src_dofs_state=self.dofs_state_adjoint_cache, + src_links_state=self.links_state_adjoint_cache, + src_joints_state=self.joints_state_adjoint_cache, + src_geoms_state=self.geoms_state_adjoint_cache, + dst_dofs_state=self.dofs_state, + dst_links_state=self.links_state, + dst_joints_state=self.joints_state, + dst_geoms_state=self.geoms_state, + ) + + kernel_step_2.grad( + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + links_info=self.links_info, + links_state=self.links_state, + joints_info=self.joints_info, + joints_state=self.joints_state, + entities_state=self.entities_state, + entities_info=self.entities_info, + geoms_info=self.geoms_info, + geoms_state=self.geoms_state, + collider_state=self.collider._collider_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=True, + ) + + kernel_compute_qacc.grad( + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=True, + ) + + # Load the current dofs_acc from adjoint cache, as it was overwritten by [kernel_compute_qacc] + self.dofs_state.acc.from_numpy(curr_dofs_acc) + + kernel_forward_dynamics_without_qacc.grad( + links_state=self.links_state, + links_info=self.links_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + joints_info=self.joints_info, + entities_state=self.entities_state, + entities_info=self.entities_info, + geoms_state=self.geoms_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=True, + ) + + # If it was the very first substep, we need to backpropagate through the initial update of the cartesian space + if self._enable_mujoco_compatibility or self._sim.cur_substep_global == 0: + kernel_update_cartesian_space.grad( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=True, + ) def substep_post_coupling(self, f): from genesis.engine.couplers import SAPCoupler, IPCCoupler @@ -1137,6 +1397,7 @@ def substep_post_coupling(self, f): dofs_state=self.dofs_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) kernel_step_2( dofs_state=self.dofs_state, @@ -1153,7 +1414,34 @@ def substep_post_coupling(self, f): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=False, + ) + kernel_copy_next_to_curr( + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + ) + kernel_save_adjoint_cache( + f=f + 1, + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + rigid_adjoint_cache=self._rigid_adjoint_cache, ) + if not self._static_rigid_sim_config.enable_mujoco_compatibility: + kernel_update_cartesian_space( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_info=self.geoms_info, + geoms_state=self.geoms_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=False, + ) elif isinstance(self.sim.coupler, IPCCoupler): # For IPCCoupler, perform full rigid body computation in post-coupling phase # This allows IPC to handle rigid bodies during the coupling phase @@ -1161,25 +1449,59 @@ def substep_post_coupling(self, f): if self.sim.coupler.options.disable_genesis_ground_contact: original_enable_collision = self._enable_collision self._enable_collision = False - self.substep() + self.substep(f) self._enable_collision = original_enable_collision else: - self.substep() + self.substep(f) def substep_post_coupling_grad(self, f): pass def add_grad_from_state(self, state): - pass + if self.is_active: + qpos_grad = gs.zeros_like(state.qpos) + dofs_vel_grad = gs.zeros_like(state.dofs_vel) + links_pos_grad = gs.zeros_like(state.links_pos) + links_quat_grad = gs.zeros_like(state.links_quat) + + if state.qpos.grad is not None: + qpos_grad = state.qpos.grad + + if state.dofs_vel.grad is not None: + dofs_vel_grad = state.dofs_vel.grad + + if state.links_pos.grad is not None: + links_pos_grad = state.links_pos.grad + + if state.links_quat.grad is not None: + links_quat_grad = state.links_quat.grad + + kernel_get_state_grad( + qpos_grad=qpos_grad, + vel_grad=dofs_vel_grad, + links_pos_grad=links_pos_grad, + links_quat_grad=links_quat_grad, + links_state=self.links_state, + dofs_state=self.dofs_state, + geoms_state=self.geoms_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) def collect_output_grads(self): """ Collect gradients from downstream queried states. """ - pass + if self._sim.cur_step_global in self._queried_states: + # one step could have multiple states + assert len(self._queried_states[self._sim.cur_step_global]) == 1 + state = self._queried_states[self._sim.cur_step_global][0] + self.add_grad_from_state(state) def reset_grad(self): - pass + for entity in self._entities: + entity.reset_grad() + self._queried_states.clear() def update_geoms_render_T(self): kernel_update_geoms_render_T( @@ -1199,9 +1521,13 @@ def update_vgeoms_render_T(self): static_rigid_sim_config=self._static_rigid_sim_config, ) - def get_state(self, f): + def get_state(self, f=None): + s_global = self.sim.cur_step_global if self.is_active: - state = RigidSolverState(self._scene) + if s_global in self._queried_states: + return self._queried_states[s_global][0] + + state = RigidSolverState(self._scene, s_global) # qpos: ti.types.ndarray(), # vel: ti.types.ndarray(), @@ -1219,6 +1545,7 @@ def get_state(self, f): kernel_get_state( qpos=state.qpos, vel=state.dofs_vel, + acc=state.dofs_acc, links_pos=state.links_pos, links_quat=state.links_quat, i_pos_shift=state.i_pos_shift, @@ -1230,6 +1557,7 @@ def get_state(self, f): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, ) + self._queried_states.append(state) else: state = None return state @@ -1240,6 +1568,7 @@ def set_state(self, f, state, envs_idx=None): kernel_set_state( qpos=state.qpos, dofs_vel=state.dofs_vel, + dofs_acc=state.dofs_acc, links_pos=state.links_pos, links_quat=state.links_quat, i_pos_shift=state.i_pos_shift, @@ -1265,6 +1594,7 @@ def set_state(self, f, state, envs_idx=None): entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) self._errno[None] = 0 @@ -1277,16 +1607,49 @@ def set_state(self, f, state, envs_idx=None): self._cur_step = -1 def process_input(self, in_backward=False): - pass + for entity in self._entities: + entity.process_input(in_backward=in_backward) def process_input_grad(self): - pass + for entity in self._entities: + entity.process_input_grad() def save_ckpt(self, ckpt_name): - pass + if ckpt_name not in self._ckpt: + self._ckpt[ckpt_name] = dict() + + self._ckpt[ckpt_name]["qpos"] = self._rigid_adjoint_cache.qpos.to_numpy() + self._ckpt[ckpt_name]["dofs_vel"] = self._rigid_adjoint_cache.dofs_vel.to_numpy() + self._ckpt[ckpt_name]["dofs_acc"] = self._rigid_adjoint_cache.dofs_acc.to_numpy() + + for entity in self._entities: + entity.save_ckpt(ckpt_name) def load_ckpt(self, ckpt_name): - pass + # Set first frame + self._rigid_global_info.qpos.from_numpy(self._ckpt[ckpt_name]["qpos"][0]) + self.dofs_state.vel.from_numpy(self._ckpt[ckpt_name]["dofs_vel"][0]) + self.dofs_state.acc.from_numpy(self._ckpt[ckpt_name]["dofs_acc"][0]) + + if not self._enable_mujoco_compatibility: + kernel_update_cartesian_space( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=False, + ) + + for entity in self._entities: + entity.load_ckpt(ckpt_name) @property def is_active(self): @@ -1378,6 +1741,15 @@ def _sanitize_1D_io_variables( gs.raise_exception("Expecting 1D output tensor.") return tensor, _inputs_idx, envs_idx + def _sanitize_1D_io_variables_grad( + self, + grad_after_sanitization, + grad_before_sanitization, + ): + if grad_after_sanitization.shape != grad_before_sanitization.shape: + gs.raise_exception("Shape of grad_after_sanitization and grad_before_sanitization do not match.") + return grad_after_sanitization + def _sanitize_2D_io_variables( self, tensor, @@ -1465,6 +1837,15 @@ def _sanitize_2D_io_variables( gs.raise_exception("Expecting 2D input tensor.") return tensor, _inputs_idx, envs_idx + def _sanitize_2D_io_variables_grad( + self, + grad_after_sanitization, + grad_before_sanitization, + ): + if grad_after_sanitization.shape != grad_before_sanitization.shape: + gs.raise_exception("Shape of grad_after_sanitization and grad_before_sanitization do not match.") + return grad_after_sanitization + def _get_qs_idx(self, qs_idx_local=None): return self._get_qs_idx_local(qs_idx_local) + self._q_start @@ -1518,8 +1899,40 @@ def set_base_links_pos( entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) + def set_base_links_pos_grad(self, links_idx, envs_idx, relative, unsafe, pos_grad): + if links_idx is None: + links_idx = self._base_links_idx + pos_grad_, links_idx, envs_idx = self._sanitize_2D_io_variables( + pos_grad.clone(), + links_idx, + self.n_links, + 3, + envs_idx, + idx_name="links_idx", + skip_allocation=True, + unsafe=unsafe, + ) + if self.n_envs == 0: + pos_grad_ = pos_grad_.unsqueeze(0) + if not unsafe and not torch.isin(links_idx, self._base_links_idx).all(): + gs.raise_exception("`links_idx` contains at least one link that is not a base link.") + kernel_set_links_pos_grad( + relative, + pos_grad_, + links_idx, + envs_idx, + links_info=self.links_info, + links_state=self.links_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + if self.n_envs == 0: + pos_grad_ = pos_grad_.squeeze(0) + pos_grad.data = self._sanitize_2D_io_variables_grad(pos_grad_, pos_grad) + def set_links_quat(self, quat, links_idx=None, envs_idx=None, *, skip_forward=False, unsafe=False): raise DeprecationError("This method has been removed. Please use 'set_base_links_quat' instead.") @@ -1569,8 +1982,41 @@ def set_base_links_quat( entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) + def set_base_links_quat_grad(self, links_idx, envs_idx, relative, unsafe, quat_grad): + if links_idx is None: + links_idx = self._base_links_idx + quat_grad_, links_idx, envs_idx = self._sanitize_2D_io_variables( + quat_grad.clone(), + links_idx, + self.n_links, + 4, + envs_idx, + idx_name="links_idx", + skip_allocation=True, + unsafe=unsafe, + ) + if self.n_envs == 0: + quat_grad_ = quat_grad_.unsqueeze(0) + if not unsafe and not torch.isin(links_idx, self._base_links_idx).all(): + gs.raise_exception("`links_idx` contains at least one link that is not a base link.") + assert relative == False, "Backward pass for relative quaternion is not supported yet." + kernel_set_links_quat_grad( + relative, + quat_grad_, + links_idx, + envs_idx, + links_info=self.links_info, + links_state=self.links_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + if self.n_envs == 0: + quat_grad_ = quat_grad_.squeeze(0) + quat_grad.data = self._sanitize_2D_io_variables_grad(quat_grad_, quat_grad) + def set_links_mass_shift(self, mass, links_idx=None, envs_idx=None, *, unsafe=False): mass, links_idx, envs_idx = self._sanitize_1D_io_variables( mass, links_idx, self.n_links, envs_idx, idx_name="links_idx", skip_allocation=True, unsafe=unsafe @@ -1647,6 +2093,7 @@ def set_qpos(self, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False, unsa entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) def set_global_sol_params(self, sol_params, *, unsafe=False): @@ -1835,8 +2282,22 @@ def set_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None, *, skip_forw entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) + def set_dofs_velocity_grad(self, dofs_idx, envs_idx, unsafe, velocity_grad): + velocity_grad_, dofs_idx, envs_idx = self._sanitize_1D_io_variables( + velocity_grad, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe + ) + if self.n_envs == 0: + velocity_grad_ = velocity_grad_.unsqueeze(0) + kernel_set_dofs_velocity_grad( + velocity_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config + ) + if self.n_envs == 0: + velocity_grad_ = velocity_grad_.squeeze(0) + velocity_grad.data = self._sanitize_1D_io_variables_grad(velocity_grad_, velocity_grad) + def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, skip_forward=False, unsafe=False): position, dofs_idx, envs_idx = self._sanitize_1D_io_variables( position, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe @@ -1876,6 +2337,7 @@ def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, skip_forw entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) def control_dofs_force(self, force, dofs_idx=None, envs_idx=None, *, unsafe=False): @@ -2401,18 +2863,36 @@ def n_links(self): return self._n_links return len(self.links) + @property + def max_n_links_per_entity(self): + if self.is_built: + return self._max_n_links_per_entity + return max([len(entity.links) for entity in self._entities]) if len(self._entities) > 0 else 0 + @property def n_joints(self): if self.is_built: return self._n_joints return len(self.joints) + @property + def max_n_joints_per_link(self): + if self.is_built: + return self._max_n_joints_per_link + return max([len(link.joints) for link in self.links]) if len(self.links) > 0 else 0 + @property def n_geoms(self): if self.is_built: return self._n_geoms return len(self.geoms) + @property + def max_n_geoms_per_entity(self): + if self.is_built: + return self._max_n_geoms_per_entity + return max([entity.n_geoms for entity in self._entities]) if len(self._entities) > 0 else 0 + @property def n_cells(self): if self.is_built: @@ -2473,12 +2953,36 @@ def n_qs(self): return self._n_qs return sum([entity.n_qs for entity in self._entities]) + @property + def max_n_qs_per_link(self): + if self.is_built: + return self._max_n_qs_per_link + return max([link.n_qs for link in self.links]) if len(self.links) > 0 else 0 + @property def n_dofs(self): if self.is_built: return self._n_dofs return sum(entity.n_dofs for entity in self._entities) + @property + def max_n_dofs_per_entity(self): + if self.is_built: + return self._max_n_dofs_per_entity + return max([entity.n_dofs for entity in self._entities]) if len(self._entities) > 0 else 0 + + @property + def max_n_dofs_per_link(self): + if self.is_built: + return self._max_n_dofs_per_link + return max([link.n_dofs for link in self.links]) if len(self.links) > 0 else 0 + + @property + def max_n_dofs_per_joint(self): + if self.is_built: + return self._max_n_dofs_per_joint + return max([joint.n_dofs for joint in self.joints]) if len(self.joints) > 0 else 0 + @property def init_qpos(self): if self._entities: @@ -2507,25 +3011,37 @@ def update_qacc_from_qvel_delta( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): n_dofs = dofs_state.ctrl_mode.shape[0] _B = dofs_state.ctrl_mode.shape[1] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): - i_d = rigid_global_info.awake_dofs[i_d_, i_b] - dofs_state.acc[i_d, i_b] = ( - dofs_state.vel[i_d, i_b] - dofs_state.vel_prev[i_d, i_b] - ) / rigid_global_info.substep_dt[None] - dofs_state.vel[i_d, i_b] = dofs_state.vel_prev[i_d, i_b] - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - dofs_state.acc[i_d, i_b] = ( - dofs_state.vel[i_d, i_b] - dofs_state.vel_prev[i_d, i_b] - ) / rigid_global_info.substep_dt[None] - dofs_state.vel[i_d, i_b] = dofs_state.vel_prev[i_d, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_dofs, _B): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_dofs[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): + i_d = ( + rigid_global_info.awake_dofs[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) + dofs_state.acc[i_d, i_b] = ( + dofs_state.vel[i_d, i_b] - dofs_state.vel_prev[i_d, i_b] + ) / rigid_global_info.substep_dt[None] + dofs_state.vel[i_d, i_b] = dofs_state.vel_prev[i_d, i_b] @ti.kernel @@ -2533,25 +3049,38 @@ def update_qvel( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): _B = dofs_state.vel.shape[1] n_dofs = dofs_state.vel.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): - i_d = rigid_global_info.awake_dofs[i_d_, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_dofs, _B): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_dofs[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): + i_d = ( + rigid_global_info.awake_dofs[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) dofs_state.vel_prev[i_d, i_b] = dofs_state.vel[i_d, i_b] dofs_state.vel[i_d, i_b] = ( dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None] ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - dofs_state.vel_prev[i_d, i_b] = dofs_state.vel[i_d, i_b] - dofs_state.vel[i_d, i_b] = ( - dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None] - ) @ti.kernel(fastcache=gs.use_fastcache) @@ -2565,6 +3094,7 @@ def kernel_compute_mass_matrix( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), decompose: ti.template(), + is_backward: ti.template(), ): func_compute_mass_matrix( implicit_damping=False, @@ -2575,6 +3105,7 @@ def kernel_compute_mass_matrix( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) if decompose: func_factor_mass( @@ -2584,6 +3115,7 @@ def kernel_compute_mass_matrix( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) @@ -3150,6 +3682,7 @@ def kernel_forward_dynamics( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): func_forward_dynamics( links_state=links_state, @@ -3163,6 +3696,7 @@ def kernel_forward_dynamics( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, contact_island_state=contact_island_state, + is_backward=is_backward, ) @@ -3174,6 +3708,7 @@ def kernel_update_acc( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): func_update_acc( update_cacc=True, @@ -3183,6 +3718,7 @@ def kernel_update_acc( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) @@ -3207,170 +3743,258 @@ def func_compute_mass_matrix( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = links_state.pos.shape[1] - n_links = links_state.pos.shape[0] - n_entities = entities_info.n_links.shape[0] - n_dofs = dofs_state.f_ang.shape[0] + # crb initialize + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - if ti.static(static_rigid_sim_config.use_hibernation): - # crb initialize - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] links_state.crb_inertial[i_l, i_b] = links_state.cinr_inertial[i_l, i_b] links_state.crb_pos[i_l, i_b] = links_state.cinr_pos[i_l, i_b] links_state.crb_quat[i_l, i_b] = links_state.cinr_quat[i_l, i_b] links_state.crb_mass[i_l, i_b] = links_state.cinr_mass[i_l, i_b] - # crb - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_l_ in range(entities_info.n_links[i_e]): - i_l = entities_info.link_end[i_e] - 1 - i_l_ - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - - if i_p != -1: - links_state.crb_inertial[i_p, i_b] = ( - links_state.crb_inertial[i_p, i_b] + links_state.crb_inertial[i_l, i_b] - ) - links_state.crb_mass[i_p, i_b] = links_state.crb_mass[i_p, i_b] + links_state.crb_mass[i_l, i_b] + # crb + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - links_state.crb_pos[i_p, i_b] = links_state.crb_pos[i_p, i_b] + links_state.crb_pos[i_l, i_b] - links_state.crb_quat[i_p, i_b] = links_state.crb_quat[i_p, i_b] + links_state.crb_quat[i_l, i_b] + for i in ( + range(entities_info.n_links[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + if i < entities_info.n_links[i_e]: + i_l = entities_info.link_end[i_e] - 1 - i + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + i_p = links_info.parent_idx[I_l] - # mass_mat - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + if i_p != -1: + links_state.crb_inertial[i_p, i_b] += links_state.crb_inertial[i_l, i_b] + links_state.crb_mass[i_p, i_b] += links_state.crb_mass[i_l, i_b] + links_state.crb_pos[i_p, i_b] += links_state.crb_pos[i_l, i_b] + links_state.crb_quat[i_p, i_b] += links_state.crb_quat[i_l, i_b] + + # mass_mat + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.f_ang[i_d, i_b], dofs_state.f_vel[i_d, i_b] = gu.inertial_mul( - links_state.crb_pos[i_l, i_b], - links_state.crb_inertial[i_l, i_b], - links_state.crb_mass[i_l, i_b], - dofs_state.cdof_vel[i_d, i_b], - dofs_state.cdof_ang[i_d, i_b], + + for i_d_ in ( + range(links_info.dof_start[I_l], links_info.dof_end[I_l]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + i_d = i_d_ if ti.static(not is_backward) else links_info.dof_start[I_l] + i_d_ + + if i_d < links_info.dof_end[I_l]: + dofs_state.f_ang[i_d, i_b], dofs_state.f_vel[i_d, i_b] = gu.inertial_mul( + links_state.crb_pos[i_l, i_b], + links_state.crb_inertial[i_l, i_b], + links_state.crb_mass[i_l, i_b], + dofs_state.cdof_vel[i_d, i_b], + dofs_state.cdof_ang[i_d, i_b], + ) + + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) + + for i_d_, j_d_ in ( + ( + # Dynamic inner loop for forward pass + ti.ndrange( + (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), + (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), + ) ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static( + ti.ndrange( + static_rigid_sim_config.max_n_dofs_per_entity, + static_rigid_sim_config.max_n_dofs_per_entity, + ) + ) + ) + ): + i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ + j_d = j_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + j_d_ - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - for j_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): + if i_d < entities_info.dof_end[i_e] and j_d < entities_info.dof_end[i_e]: rigid_global_info.mass_mat[i_d, j_d, i_b] = ( dofs_state.f_ang[i_d, i_b].dot(dofs_state.cdof_ang[j_d, i_b]) + dofs_state.f_vel[i_d, i_b].dot(dofs_state.cdof_vel[j_d, i_b]) ) * rigid_global_info.mass_parent_mask[i_d, j_d] # FIXME: Updating the lower-part of the mass matrix is irrelevant - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - for j_d in range(i_d + 1, entities_info.dof_end[i_e]): + for i_d_, j_d_ in ( + ( + # Dynamic inner loop for forward pass + ti.ndrange( + (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), + (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), + ) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static( + ti.ndrange( + static_rigid_sim_config.max_n_dofs_per_entity, + static_rigid_sim_config.max_n_dofs_per_entity, + ) + ) + ) + ): + i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ + j_d = j_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + j_d_ + + if i_d < entities_info.dof_end[i_e] and j_d < entities_info.dof_end[i_e] and j_d > i_d: rigid_global_info.mass_mat[i_d, j_d, i_b] = rigid_global_info.mass_mat[j_d, i_d, i_b] + # In below blocks, we only update the diagonal terms of the mass matrix, which have not been read yet. # Take into account motor armature - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] = ( - rigid_global_info.mass_mat[i_d, i_d, i_b] + dofs_info.armature[I_d] - ) + for i_d_ in ( + range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ + + if i_d < entities_info.dof_end[i_e]: + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.armature[I_d] # Take into account first-order correction terms for implicit integration scheme right away if ti.static(implicit_damping): - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] += ( - dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] - ) - if ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY - ): - # qM += d qfrc_actuator / d qvel + for i_d_ in ( + range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ + + if i_d < entities_info.dof_end[i_e]: + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d rigid_global_info.mass_mat[i_d, i_d, i_b] += ( - dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] ) - else: - # crb initialize - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - links_state.crb_inertial[i_l, i_b] = links_state.cinr_inertial[i_l, i_b] - links_state.crb_pos[i_l, i_b] = links_state.cinr_pos[i_l, i_b] - links_state.crb_quat[i_l, i_b] = links_state.cinr_quat[i_l, i_b] - links_state.crb_mass[i_l, i_b] = links_state.cinr_mass[i_l, i_b] - - # crb - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e, i_b in ti.ndrange(n_entities, _B): - for i_l_ in range(entities_info.n_links[i_e]): - i_l = entities_info.link_end[i_e] - 1 - i_l_ - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - - if i_p != -1: - links_state.crb_inertial[i_p, i_b] = ( - links_state.crb_inertial[i_p, i_b] + links_state.crb_inertial[i_l, i_b] - ) - links_state.crb_mass[i_p, i_b] = links_state.crb_mass[i_p, i_b] + links_state.crb_mass[i_l, i_b] - - links_state.crb_pos[i_p, i_b] = links_state.crb_pos[i_p, i_b] + links_state.crb_pos[i_l, i_b] - links_state.crb_quat[i_p, i_b] = links_state.crb_quat[i_p, i_b] + links_state.crb_quat[i_l, i_b] - - # mass_mat - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.f_ang[i_d, i_b], dofs_state.f_vel[i_d, i_b] = gu.inertial_mul( - links_state.crb_pos[i_l, i_b], - links_state.crb_inertial[i_l, i_b], - links_state.crb_mass[i_l, i_b], - dofs_state.cdof_vel[i_d, i_b], - dofs_state.cdof_ang[i_d, i_b], - ) - - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_e, i_b in ti.ndrange(n_entities, _B): - for i_d, j_d in ti.ndrange( - (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), - (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), - ): - rigid_global_info.mass_mat[i_d, j_d, i_b] = ( - dofs_state.f_ang[i_d, i_b].dot(dofs_state.cdof_ang[j_d, i_b]) - + dofs_state.f_vel[i_d, i_b].dot(dofs_state.cdof_vel[j_d, i_b]) - ) * rigid_global_info.mass_parent_mask[i_d, j_d] - - # FIXME: Updating the lower-part of the mass matrix is irrelevant - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - for j_d in range(i_d + 1, entities_info.dof_end[i_e]): - rigid_global_info.mass_mat[i_d, j_d, i_b] = rigid_global_info.mass_mat[j_d, i_d, i_b] - - # Take into account motor armature - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] = ( - rigid_global_info.mass_mat[i_d, i_d, i_b] + dofs_info.armature[I_d] - ) - - # Take into account first-order correction terms for implicit integration scheme right away - if ti.static(implicit_damping): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] - if ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY - ): - # qM += d qfrc_actuator / d qvel - rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] + if ( + dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + ): + # qM += d qfrc_actuator / d qvel + rigid_global_info.mass_mat[i_d, i_d, i_b] += ( + dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] + ) @ti.func @@ -3381,115 +4005,281 @@ def func_factor_mass( dofs_info: array_class.DofsInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): """ Compute Cholesky decomposition (L^T @ D @ L) of mass matrix. """ - _B = dofs_state.ctrl_mode.shape[1] - n_entities = entities_info.n_links.shape[0] + if ti.static(not is_backward): + _B = dofs_state.ctrl_mode.shape[1] + n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_entities, _B) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - if rigid_global_info.mass_mat_mask[i_e, i_b]: - entity_dof_start = entities_info.dof_start[i_e] - entity_dof_end = entities_info.dof_end[i_e] - n_dofs = entities_info.n_dofs[i_e] + if rigid_global_info.mass_mat_mask[i_e, i_b]: + entity_dof_start = entities_info.dof_start[i_e] + entity_dof_end = entities_info.dof_end[i_e] + n_dofs = entities_info.n_dofs[i_e] - for i_d in range(entity_dof_start, entity_dof_end): - for j_d in range(entity_dof_start, i_d + 1): - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b] + for i_d_ in ( + range(entity_dof_start, entity_dof_end) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ - if ti.static(implicit_damping): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( - dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] - ) - if ti.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): - if ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + if i_d < entity_dof_end: + for j_d_ in ( + range(entity_dof_start, i_d + 1) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): + j_d = j_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + j_d_ + + if j_d < i_d + 1: + rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[ + i_d, j_d, i_b + ] + + if ti.static(implicit_damping): + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( - dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] ) - - for i_d_ in range(n_dofs): - i_d = entity_dof_end - i_d_ - 1 - rigid_global_info.mass_mat_D_inv[i_d, i_b] = 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b] - - for j_d_ in range(i_d - entity_dof_start): - j_d = i_d - j_d_ - 1 - a = rigid_global_info.mass_mat_L[i_d, j_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] - for k_d in range(entity_dof_start, j_d + 1): - rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= ( - a * rigid_global_info.mass_mat_L[i_d, k_d, i_b] + if ti.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): + if (dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION) or ( + dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + ): + rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( + dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] + ) + + for i_d_ in ( + range(n_dofs) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if i_d_ < n_dofs: + i_d = entity_dof_end - i_d_ - 1 + rigid_global_info.mass_mat_D_inv[i_d, i_b] = ( + 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b] ) - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a - # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. - rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 + for j_d_ in ( + range(i_d - entity_dof_start) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if j_d_ < i_d - entity_dof_start: + j_d = i_d - j_d_ - 1 + a = ( + rigid_global_info.mass_mat_L[i_d, j_d, i_b] + * rigid_global_info.mass_mat_D_inv[i_d, i_b] + ) + + for k_d_ in ( + range(entity_dof_start, j_d + 1) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + k_d = ( + k_d_ + if ti.static(not is_backward) + else entities_info.dof_start[i_e] + k_d_ + ) + if k_d < j_d + 1: + rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= ( + a * rigid_global_info.mass_mat_L[i_d, k_d, i_b] + ) + rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a + + # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. + rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 + else: - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_e, i_b in ti.ndrange(n_entities, _B): + # Cholesky decomposition that has safe access pattern and robust handling of divide by zero for AD. Even though + # it is logically equivalent to the above block, it shows slightly numerical difference in the result, and thus + # it fails for a unit test ("test_urdf_rope"), while passing all the others. TODO: Investigate if we can fix this + # and only use this block. + + # Assume this is the outermost loop + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)) + for i_e, i_b in ti.ndrange(entities_info.n_links.shape[0], dofs_state.ctrl_mode.shape[1]): if rigid_global_info.mass_mat_mask[i_e, i_b]: entity_dof_start = entities_info.dof_start[i_e] entity_dof_end = entities_info.dof_end[i_e] n_dofs = entities_info.n_dofs[i_e] - for i_d in range(entity_dof_start, entity_dof_end): - for j_d in range(entity_dof_start, i_d + 1): - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b] + for i_d0 in ( + range(n_dofs) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if i_d0 < n_dofs: + i_d = entity_dof_start + i_d0 + i_pr = (entity_dof_start + entity_dof_end - 1) - i_d + for j_d_ in ( + range(entity_dof_start, i_d + 1) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + j_d = j_d_ if ti.static(not is_backward) else (j_d_ + entities_info.dof_start[i_e]) + j_pr = (entity_dof_start + entity_dof_end - 1) - j_d + if j_d < i_d + 1: + rigid_global_info.mass_mat_L_bw[0, i_pr, j_pr, i_b] = rigid_global_info.mass_mat[ + i_d, j_d, i_b + ] + rigid_global_info.mass_mat_L_bw[0, j_pr, i_pr, i_b] = rigid_global_info.mass_mat[ + i_d, j_d, i_b + ] - if ti.static(implicit_damping): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( - dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] - ) - if ti.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): - if ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY - ): - rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( - dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] - ) + if ti.static(implicit_damping): + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + rigid_global_info.mass_mat_L_bw[0, i_pr, i_pr, i_b] += ( + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] + ) + if ti.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): + if ( + dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + ): + rigid_global_info.mass_mat_L_bw[0, i_pr, i_pr, i_b] += ( + dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] + ) - for i_d_ in range(n_dofs): - i_d = entity_dof_end - i_d_ - 1 - rigid_global_info.mass_mat_D_inv[i_d, i_b] = 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b] + # Cholesky-Banachiewicz algorithm (in the perturbed indices), access pattern is safe for autodiff + # https://en.wikipedia.org/wiki/Cholesky_decomposition + for p_i0 in ( + range(n_dofs) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + for p_j0 in ( + range(p_i0 + 1) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if p_i0 < n_dofs and p_j0 < n_dofs and p_j0 <= p_i0: + # j_pr <= i_pr + i_pr = entity_dof_start + p_i0 + j_pr = entity_dof_start + p_j0 + + sum = gs.ti_float(0.0) + for p_k0 in ( + range(p_j0) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + # k_pr < j_pr + if p_k0 < p_j0: + k_pr = entity_dof_start + p_k0 + sum += ( + rigid_global_info.mass_mat_L_bw[1, i_pr, k_pr, i_b] + * rigid_global_info.mass_mat_L_bw[1, j_pr, k_pr, i_b] + ) - for j_d_ in range(i_d - entity_dof_start): - j_d = i_d - j_d_ - 1 - a = rigid_global_info.mass_mat_L[i_d, j_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] - for k_d in range(entity_dof_start, j_d + 1): - rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= ( - a * rigid_global_info.mass_mat_L[i_d, k_d, i_b] - ) - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a + a = rigid_global_info.mass_mat_L_bw[0, i_pr, j_pr, i_b] - sum + b = ti.math.clamp(rigid_global_info.mass_mat_L_bw[1, j_pr, j_pr, i_b], gs.EPS, ti.math.inf) + if p_i0 == p_j0: + rigid_global_info.mass_mat_L_bw[1, i_pr, j_pr, i_b] = ti.sqrt( + ti.math.clamp(a, gs.EPS, ti.math.inf) + ) + else: + rigid_global_info.mass_mat_L_bw[1, i_pr, j_pr, i_b] = a / b - # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. - rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 + for i_d0 in ( + range(n_dofs) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + for i_d1 in ( + range(i_d0 + 1) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if i_d0 < n_dofs and i_d1 < n_dofs and i_d1 <= i_d0: + i_d = entity_dof_start + i_d0 + j_d = entity_dof_start + i_d1 + i_pr = (entity_dof_start + entity_dof_end - 1) - i_d + j_pr = (entity_dof_start + entity_dof_end - 1) - j_d + + a = rigid_global_info.mass_mat_L_bw[1, i_pr, i_pr, i_b] + rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat_L_bw[ + 1, j_pr, i_pr, i_b + ] / ti.math.clamp(a, gs.EPS, ti.math.inf) + + if i_d == j_d: + rigid_global_info.mass_mat_D_inv[i_d, i_b] = 1.0 / ( + ti.math.clamp(a**2, gs.EPS, ti.math.inf) + ) @ti.func def func_solve_mass_batched( vec: array_class.V_ANNOTATION, out: array_class.V_ANNOTATION, + out_bw: array_class.V_ANNOTATION, i_b: ti.int32, entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): + # This loop is considered an inner loop + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) + for i_0 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(entities_info.n_links.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(entities_info.n_links.shape[0])) + ) + ): + n_entities = entities_info.n_links.shape[0] - n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] + if i_0 < ( + rigid_global_info.n_awake_entities[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else n_entities + ): + i_e = ( + rigid_global_info.awake_entities[i_0, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) if rigid_global_info.mass_mat_mask[i_e, i_b]: entity_dof_start = entities_info.dof_start[i_e] @@ -3497,63 +4287,93 @@ def func_solve_mass_batched( n_dofs = entities_info.n_dofs[i_e] # Step 1: Solve w st. L^T @ w = y - for i_d_ in range(n_dofs): - i_d = entity_dof_end - i_d_ - 1 - out[i_d, i_b] = vec[i_d, i_b] - for j_d in range(i_d + 1, entity_dof_end): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b] + for i_d_ in ( + range(n_dofs) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if i_d_ < n_dofs: + i_d = entity_dof_end - i_d_ - 1 + if ti.static(is_backward): + out_bw[0, i_d, i_b] = vec[i_d, i_b] + else: + out[i_d, i_b] = vec[i_d, i_b] + + for j_d_ in ( + range(i_d + 1, entity_dof_end) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + j_d = j_d_ if ti.static(not is_backward) else (j_d_ + entities_info.dof_start[i_e]) + if j_d >= i_d + 1 and j_d < entity_dof_end: + # Since we read out[j_d, i_b], and j_d > i_d, which means that out[j_d, i_b] is already + # finalized at this point, we don't need to care about AD mutation rule. + if ti.static(is_backward): + out_bw[0, i_d, i_b] += -( + rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out_bw[0, j_d, i_b] + ) + else: + out[i_d, i_b] += -(rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b]) # Step 2: z = D^{-1} w - for i_d in range(entity_dof_start, entity_dof_end): - out[i_d, i_b] *= rigid_global_info.mass_mat_D_inv[i_d, i_b] + for i_d_ in ( + range(entity_dof_start, entity_dof_end) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + entities_info.dof_start[i_e]) + if i_d < entity_dof_end: + if ti.static(is_backward): + out_bw[1, i_d, i_b] = out_bw[0, i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] + else: + out[i_d, i_b] *= rigid_global_info.mass_mat_D_inv[i_d, i_b] # Step 3: Solve x st. L @ x = z - for i_d in range(entity_dof_start, entity_dof_end): - for j_d in range(entity_dof_start, i_d): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b] - else: - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_e in range(n_entities): - if rigid_global_info.mass_mat_mask[i_e, i_b]: - entity_dof_start = entities_info.dof_start[i_e] - entity_dof_end = entities_info.dof_end[i_e] - n_dofs = entities_info.n_dofs[i_e] - - # Step 1: Solve w st. L^T @ w = y - for i_d_ in range(n_dofs): - i_d = entity_dof_end - i_d_ - 1 - out[i_d, i_b] = vec[i_d, i_b] - for j_d in range(i_d + 1, entity_dof_end): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b] - - # Step 2: z = D^{-1} w - for i_d in range(entity_dof_start, entity_dof_end): - out[i_d, i_b] *= rigid_global_info.mass_mat_D_inv[i_d, i_b] + for i_d_ in ( + range(entity_dof_start, entity_dof_end) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + entities_info.dof_start[i_e]) + if i_d < entity_dof_end: + curr_out = out[i_d, i_b] + if ti.static(is_backward): + curr_out = out_bw[1, i_d, i_b] + + for j_d_ in ( + range(entity_dof_start, i_d) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + j_d = j_d_ if ti.static(not is_backward) else (j_d_ + entities_info.dof_start[i_e]) + if j_d < i_d: + curr_out += -(rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b]) - # Step 3: Solve x st. L @ x = z - for i_d in range(entity_dof_start, entity_dof_end): - for j_d in range(entity_dof_start, i_d): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b] + out[i_d, i_b] = curr_out @ti.func def func_solve_mass( vec: array_class.V_ANNOTATION, out: array_class.V_ANNOTATION, + out_bw: array_class.V_ANNOTATION, # Should not be None if backward entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = out.shape[1] + # This loop must be the outermost loop to be differentiable ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): + for i_b in range(out.shape[1]): func_solve_mass_batched( vec, out, + out_bw, i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) @@ -3629,6 +4449,7 @@ def kernel_rigid_entity_inverse_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, + False, ) # compute error solved = True @@ -3757,6 +4578,7 @@ def kernel_rigid_entity_inverse_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, + False, ) solved = True for i_ee in range(n_links): @@ -3866,6 +4688,7 @@ def kernel_rigid_entity_inverse_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, + False, ) @@ -3884,6 +4707,7 @@ def func_forward_dynamics( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): func_compute_mass_matrix( implicit_damping=ti.static(static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast), @@ -3894,6 +4718,7 @@ def func_forward_dynamics( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_factor_mass( implicit_damping=False, @@ -3902,6 +4727,7 @@ def func_forward_dynamics( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_torque_and_passive_force( entities_state=entities_state, @@ -3915,6 +4741,7 @@ def func_forward_dynamics( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, contact_island_state=contact_island_state, + is_backward=is_backward, ) func_update_acc( update_cacc=False, @@ -3924,6 +4751,7 @@ def func_forward_dynamics( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_update_force( links_state=links_state, @@ -3931,6 +4759,7 @@ def func_forward_dynamics( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) # self._func_actuation() func_bias_force( @@ -3939,30 +4768,146 @@ def func_forward_dynamics( links_info=links_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_compute_qacc( dofs_state=dofs_state, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) -@ti.kernel(fastcache=gs.use_fastcache) -def kernel_clear_external_force( +@ti.kernel +def kernel_forward_dynamics_without_qacc( links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + joints_info: array_class.JointsInfo, + entities_state: array_class.EntitiesState, + entities_info: array_class.EntitiesInfo, + geoms_state: array_class.GeomsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): - func_clear_external_force( + func_compute_mass_matrix( + implicit_damping=ti.static(static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast), links_state=links_state, + links_info=links_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) - - -@ti.func -def func_update_cartesian_space( + func_factor_mass( + implicit_damping=False, + entities_info=entities_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) + func_torque_and_passive_force( + entities_state=entities_state, + entities_info=entities_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + links_state=links_state, + links_info=links_info, + joints_info=joints_info, + geoms_state=geoms_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + contact_island_state=contact_island_state, + is_backward=is_backward, + ) + func_update_acc( + update_cacc=False, + dofs_state=dofs_state, + links_info=links_info, + links_state=links_state, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) + func_update_force( + links_state=links_state, + links_info=links_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) + # self._func_actuation() + func_bias_force( + dofs_state=dofs_state, + links_state=links_state, + links_info=links_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) + + +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_clear_external_force( + links_state: array_class.LinksState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + func_clear_external_force( + links_state=links_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_update_cartesian_space( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + geoms_info: array_class.GeomsInfo, + geoms_state: array_class.GeomsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + force_update_fixed_geoms: ti.template(), + is_backward: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(links_state.pos.shape[1]): + func_update_cartesian_space( + i_b=i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + geoms_info=geoms_info, + geoms_state=geoms_state, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + force_update_fixed_geoms=force_update_fixed_geoms, + is_backward=is_backward, + ) + + +@ti.func +def func_update_cartesian_space( i_b, links_state: array_class.LinksState, links_info: array_class.LinksInfo, @@ -3976,6 +4921,7 @@ def func_update_cartesian_space( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), force_update_fixed_geoms: ti.template(), + is_backward: ti.template(), ): func_forward_kinematics( i_b, @@ -3988,6 +4934,7 @@ def func_update_cartesian_space( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_COM_links( i_b, @@ -4000,6 +4947,7 @@ def func_update_cartesian_space( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_forward_velocity( i_b, @@ -4010,6 +4958,7 @@ def func_update_cartesian_space( dofs_state=dofs_state, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_update_geoms( @@ -4021,6 +4970,7 @@ def func_update_cartesian_space( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=force_update_fixed_geoms, + is_backward=is_backward, ) @@ -4039,6 +4989,7 @@ def kernel_step_1( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): if ti.static(static_rigid_sim_config.enable_mujoco_compatibility): _B = links_state.pos.shape[1] @@ -4058,6 +5009,7 @@ def kernel_step_1( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=False, + is_backward=is_backward, ) func_forward_dynamics( @@ -4072,6 +5024,7 @@ def kernel_step_1( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, contact_island_state=contact_island_state, + is_backward=is_backward, ) @@ -4082,6 +5035,7 @@ def func_implicit_damping( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): EPS = rigid_global_info.EPS[None] @@ -4101,16 +5055,22 @@ def func_implicit_damping( for i_e, i_b in ti.ndrange(n_entities, _B): entity_dof_start = entities_info.dof_start[i_e] entity_dof_end = entities_info.dof_end[i_e] - for i_d in range(entity_dof_start, entity_dof_end): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - if dofs_info.damping[I_d] > EPS: - rigid_global_info.mass_mat_mask[i_e, i_b] = True - if ti.static(static_rigid_sim_config.integrator != gs.integrator.Euler): - if ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY - ) and dofs_info.kv[I_d] > EPS: + for i_d_ in ( + range(entity_dof_start, entity_dof_end) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ + if i_d < entity_dof_end: + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + if dofs_info.damping[I_d] > EPS: rigid_global_info.mass_mat_mask[i_e, i_b] = True + if ti.static(static_rigid_sim_config.integrator != gs.integrator.Euler): + if ( + dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + ) and dofs_info.kv[I_d] > EPS: + rigid_global_info.mass_mat_mask[i_e, i_b] = True func_factor_mass( implicit_damping=True, @@ -4119,13 +5079,16 @@ def func_implicit_damping( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_solve_mass( vec=dofs_state.force, out=dofs_state.acc, + out_bw=dofs_state.acc_bw, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) # Disable pre-computed factorization mask right away @@ -4153,6 +5116,7 @@ def kernel_step_2( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): # Position, Velocity and Acceleration data must be consistent when computing links acceleration, otherwise it # would not corresponds to anyting physical. There is no other way than doing this right before integration, @@ -4167,6 +5131,7 @@ def kernel_step_2( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) if ti.static(static_rigid_sim_config.integrator != gs.integrator.approximate_implicitfast): @@ -4176,6 +5141,7 @@ def kernel_step_2( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_integrate( @@ -4184,6 +5150,7 @@ def kernel_step_2( joints_info=joints_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) if ti.static(static_rigid_sim_config.use_hibernation): @@ -4206,26 +5173,6 @@ def kernel_step_2( static_rigid_sim_config=static_rigid_sim_config, ) - if ti.static(not static_rigid_sim_config.enable_mujoco_compatibility): - _B = links_state.pos.shape[1] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - func_update_cartesian_space( - i_b=i_b, - links_state=links_state, - links_info=links_info, - joints_state=joints_state, - joints_info=joints_info, - dofs_state=dofs_state, - dofs_info=dofs_info, - geoms_info=geoms_info, - geoms_state=geoms_state, - entities_info=entities_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - force_update_fixed_geoms=False, - ) - @ti.kernel(fastcache=gs.use_fastcache) def kernel_forward_kinematics_links_geoms( @@ -4241,6 +5188,7 @@ def kernel_forward_kinematics_links_geoms( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): for i_b_ in range(envs_idx.shape[0]): i_b = envs_idx[i_b_] @@ -4259,6 +5207,7 @@ def kernel_forward_kinematics_links_geoms( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=True, + is_backward=is_backward, ) @@ -4274,26 +5223,65 @@ def func_COM_links( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - EPS = rigid_global_info.EPS[None] - - n_links = links_info.root_idx.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) - links_state.root_COM[i_l, i_b].fill(0.0) + links_state.root_COM_bw[i_l, i_b].fill(0.0) links_state.mass_sum[i_l, i_b] = 0.0 - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] ( - links_state.i_pos[i_l, i_b], + links_state.i_pos_bw[i_l, i_b], links_state.i_quat[i_l, i_b], ) = gu.ti_transform_pos_quat_by_trans_quat( links_info.inertial_pos[I_l] + links_state.i_pos_shift[i_l, i_b], @@ -4304,32 +5292,95 @@ def func_COM_links( i_r = links_info.root_idx[I_l] links_state.mass_sum[i_r, i_b] += mass - links_state.root_COM[i_r, i_b] += mass * links_state.i_pos[i_l, i_b] + links_state.root_COM_bw[i_r, i_b] += mass * links_state.i_pos_bw[i_l, i_b] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_r = links_info.root_idx[I_l] - if i_l == i_r: - links_state.root_COM[i_l, i_b] = links_state.root_COM[i_l, i_b] / links_state.mass_sum[i_l, i_b] + if i_l == i_r and links_state.mass_sum[i_l, i_b] > 0.0: + links_state.root_COM[i_l, i_b] = links_state.root_COM_bw[i_l, i_b] / links_state.mass_sum[i_l, i_b] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_r = links_info.root_idx[I_l] links_state.root_COM[i_l, i_b] = links_state.root_COM[i_r, i_b] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_r = links_info.root_idx[I_l] - links_state.i_pos[i_l, i_b] = links_state.i_pos[i_l, i_b] - links_state.root_COM[i_l, i_b] + links_state.i_pos[i_l, i_b] = links_state.i_pos_bw[i_l, i_b] - links_state.root_COM[i_l, i_b] i_inertial = links_info.inertial_i[I_l] i_mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] @@ -4339,246 +5390,167 @@ def func_COM_links( links_state.cinr_quat[i_l, i_b], links_state.cinr_mass[i_l, i_b], ) = gu.ti_transform_inertia_by_trans_quat( - i_inertial, i_mass, links_state.i_pos[i_l, i_b], links_state.i_quat[i_l, i_b], EPS - ) - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - - i_p = links_info.parent_idx[I_l] - - _i_j = links_info.joint_start[I_l] - _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j - joint_type = joints_info.type[_I_j] - - p_pos = ti.Vector.zero(gs.ti_float, 3) - p_quat = gu.ti_identity_quat() - if i_p != -1: - p_pos = links_state.pos[i_p, i_b] - p_quat = links_state.quat[i_p, i_b] - - if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1): - links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b] - links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b] - else: - ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat) - - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - - ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - joints_info.pos[I_j], - gu.ti_identity_quat(), - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) - - # cdof_fn - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - - if joint_type == gs.JOINT_TYPE.FREE: - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - dofs_state.cdof_vel[i_d, i_b] = dofs_info.motion_vel[I_d] - dofs_state.cdof_ang[i_d, i_b] = gu.ti_transform_by_quat( - dofs_info.motion_ang[I_d], links_state.j_quat[i_l, i_b] - ) - - offset_pos = links_state.root_COM[i_l, i_b] - links_state.j_pos[i_l, i_b] - ( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - ) = gu.ti_transform_motion_by_trans_quat( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - offset_pos, - gu.ti_identity_quat(), - ) - - dofs_state.cdofvel_ang[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - dofs_state.cdofvel_vel[i_d, i_b] = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - - elif joint_type == gs.JOINT_TYPE.FIXED: - pass - else: - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - motion_vel = dofs_info.motion_vel[I_d] - motion_ang = dofs_info.motion_ang[I_d] - - dofs_state.cdof_ang[i_d, i_b] = gu.ti_transform_by_quat(motion_ang, links_state.j_quat[i_l, i_b]) - dofs_state.cdof_vel[i_d, i_b] = gu.ti_transform_by_quat(motion_vel, links_state.j_quat[i_l, i_b]) - - offset_pos = links_state.root_COM[i_l, i_b] - links_state.j_pos[i_l, i_b] - ( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - ) = gu.ti_transform_motion_by_trans_quat( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - offset_pos, - gu.ti_identity_quat(), - ) - - dofs_state.cdofvel_ang[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - dofs_state.cdofvel_vel[i_d, i_b] = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - links_state.root_COM[i_l, i_b].fill(0.0) - links_state.mass_sum[i_l, i_b] = 0.0 - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] - ( + i_inertial, + i_mass, links_state.i_pos[i_l, i_b], links_state.i_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - links_info.inertial_pos[I_l] + links_state.i_pos_shift[i_l, i_b], - links_info.inertial_quat[I_l], - links_state.pos[i_l, i_b], - links_state.quat[i_l, i_b], + rigid_global_info.EPS[None], ) - i_r = links_info.root_idx[I_l] - links_state.mass_sum[i_r, i_b] += mass - links_state.root_COM[i_r, i_b] += mass * links_state.i_pos[i_l, i_b] - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - i_r = links_info.root_idx[I_l] - if i_l == i_r: - if links_state.mass_sum[i_l, i_b] > 0.0: - links_state.root_COM[i_l, i_b] = links_state.root_COM[i_l, i_b] / links_state.mass_sum[i_l, i_b] - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_r = links_info.root_idx[I_l] - links_state.root_COM[i_l, i_b] = links_state.root_COM[i_r, i_b] - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + if links_info.n_dofs[I_l] > 0: + i_p = links_info.parent_idx[I_l] - i_r = links_info.root_idx[I_l] - links_state.i_pos[i_l, i_b] = links_state.i_pos[i_l, i_b] - links_state.root_COM[i_l, i_b] + _i_j = links_info.joint_start[I_l] + _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j + joint_type = joints_info.type[_I_j] - i_inertial = links_info.inertial_i[I_l] - i_mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] - ( - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_pos[i_l, i_b], - links_state.cinr_quat[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - ) = gu.ti_transform_inertia_by_trans_quat( - i_inertial, i_mass, links_state.i_pos[i_l, i_b], links_state.i_quat[i_l, i_b], EPS - ) + p_pos = ti.Vector.zero(gs.ti_float, 3) + p_quat = gu.ti_identity_quat() + if i_p != -1: + p_pos = links_state.pos[i_p, i_b] + p_quat = links_state.quat[i_p, i_b] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue + if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1): + links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b] + links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b] + else: + ( + links_state.j_pos_bw[i_l, 0, i_b], + links_state.j_quat_bw[i_l, 0, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat) - i_p = links_info.parent_idx[I_l] + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] - _i_j = links_info.joint_start[I_l] - _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j - joint_type = joints_info.type[_I_j] + for i_j_ in ( + range(n_joints) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] - p_pos = ti.Vector.zero(gs.ti_float, 3) - p_quat = gu.ti_identity_quat() - if i_p != -1: - p_pos = links_state.pos[i_p, i_b] - p_quat = links_state.quat[i_p, i_b] + curr_i_j = 0 if ti.static(not is_backward) else i_j_ + next_i_j = 0 if ti.static(not is_backward) else i_j_ + 1 - if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1): - links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b] - links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b] - else: - ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat) + if i_j < links_info.joint_end[I_l]: + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + ( + links_state.j_pos_bw[i_l, next_i_j, i_b], + links_state.j_quat_bw[i_l, next_i_j, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat( + joints_info.pos[I_j], + gu.ti_identity_quat(), + links_state.j_pos_bw[i_l, curr_i_j, i_b], + links_state.j_quat_bw[i_l, curr_i_j, i_b], + ) - ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - joints_info.pos[I_j], - gu.ti_identity_quat(), - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) + i_j_ = 0 if ti.static(not is_backward) else n_joints + links_state.j_pos[i_l, i_b] = links_state.j_pos_bw[i_l, i_j_, i_b] + links_state.j_quat[i_l, i_b] = links_state.j_quat_bw[i_l, i_j_, i_b] - # cdof_fn - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - - dof_start = joints_info.dof_start[I_j] - - if joint_type == gs.JOINT_TYPE.REVOLUTE: - dofs_state.cdof_ang[dof_start, i_b] = joints_state.xaxis[i_j, i_b] - dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b].cross(offset_pos) - elif joint_type == gs.JOINT_TYPE.PRISMATIC: - dofs_state.cdof_ang[dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) - dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b] - elif joint_type == gs.JOINT_TYPE.SPHERICAL: - xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() - for j in ti.static(range(3)): - dofs_state.cdof_ang[j + dof_start, i_b] = xmat_T[j, :] - dofs_state.cdof_vel[j + dof_start, i_b] = xmat_T[j, :].cross(offset_pos) - elif joint_type == gs.JOINT_TYPE.FREE: - for j in ti.static(range(3)): - dofs_state.cdof_ang[j + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) - dofs_state.cdof_vel[j + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) - dofs_state.cdof_vel[j + dof_start, i_b][j] = 1.0 - - xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() - for j in ti.static(range(3)): - dofs_state.cdof_ang[j + dof_start + 3, i_b] = xmat_T[j, :] - dofs_state.cdof_vel[j + dof_start + 3, i_b] = xmat_T[j, :].cross(offset_pos) - - for i_d in range(dof_start, joints_info.dof_end[I_j]): - dofs_state.cdofvel_ang[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - dofs_state.cdofvel_vel[i_d, i_b] = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + if links_info.n_dofs[I_l] > 0: + for i_j_ in ( + range(links_info.joint_start[I_l], links_info.joint_end[I_l]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ if ti.static(not is_backward) else (i_j_ + links_info.joint_start[I_l]) + + if i_j < links_info.joint_end[I_l]: + offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b] + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + + dof_start = joints_info.dof_start[I_j] + + EPS = rigid_global_info.EPS[None] + if joint_type == gs.JOINT_TYPE.REVOLUTE: + dofs_state.cdof_ang[dof_start, i_b] = joints_state.xaxis[i_j, i_b] + dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b].cross(offset_pos) + elif joint_type == gs.JOINT_TYPE.PRISMATIC: + dofs_state.cdof_ang[dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b] + elif joint_type == gs.JOINT_TYPE.SPHERICAL: + xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start, i_b] = xmat_T[i, :] + dofs_state.cdof_vel[i + dof_start, i_b] = xmat_T[i, :].cross(offset_pos) + elif joint_type == gs.JOINT_TYPE.FREE: + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[i + dof_start, i_b][i] = 1.0 + + xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start + 3, i_b] = xmat_T[i, :] + dofs_state.cdof_vel[i + dof_start + 3, i_b] = xmat_T[i, :].cross(offset_pos) + + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + dof_start) + if i_d < joints_info.dof_end[I_j]: + dofs_state.cdofvel_ang[i_d, i_b] = ( + dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + dofs_state.cdofvel_vel[i_d, i_b] = ( + dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) @ti.func @@ -4593,28 +5565,32 @@ def func_forward_kinematics( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - func_forward_kinematics_entity( - i_e, - i_b, - links_state, - links_info, - joints_state, - joints_info, - dofs_state, - dofs_info, - entities_info, - rigid_global_info, - static_rigid_sim_config, + for i_e_ in ( + ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(entities_info.n_links.shape[0]) + ) + if ti.static(not is_backward) + else ( + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(entities_info.n_links.shape[0])) + ) + ): + if i_e_ < ( + rigid_global_info.n_awake_entities[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else entities_info.n_links.shape[0] + ): + i_e = ( + rigid_global_info.awake_entities[i_e_, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_e_ ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e in range(n_entities): + func_forward_kinematics_entity( i_e, i_b, @@ -4627,6 +5603,7 @@ def func_forward_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, + is_backward, ) @@ -4640,37 +5617,39 @@ def func_forward_velocity( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - func_forward_velocity_entity( - i_e=i_e, - i_b=i_b, - entities_info=entities_info, - links_info=links_info, - links_state=links_state, - joints_info=joints_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e in range(n_entities): - func_forward_velocity_entity( - i_e=i_e, - i_b=i_b, - entities_info=entities_info, - links_info=links_info, - links_state=links_state, - joints_info=joints_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) + for i_e_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_entities) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(entities_info.n_links.shape[0])) + ) + ): + i_e = ( + rigid_global_info.awake_entities[i_e_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_e_ + ) + func_forward_velocity_entity( + i_e=i_e, + i_b=i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) @ti.kernel(fastcache=gs.use_fastcache) @@ -4686,6 +5665,7 @@ def kernel_forward_kinematics_entity( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): for i_b_ in range(envs_idx.shape[0]): i_b = envs_idx[i_b_] @@ -4702,6 +5682,7 @@ def kernel_forward_kinematics_entity( entities_info, rigid_global_info, static_rigid_sim_config, + is_backward, ) @@ -4718,106 +5699,143 @@ def func_forward_kinematics_entity( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - EPS = rigid_global_info.EPS[None] + # Becomes static loop in backward pass, because we assume this loop is an inner loop + for i_l_ in ( + range(entities_info.link_start[i_e], entities_info.link_end[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + EPS = rigid_global_info.EPS[None] + i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - pos = links_info.pos[I_l] - quat = links_info.quat[I_l] - if links_info.parent_idx[I_l] != -1: - parent_pos = links_state.pos[links_info.parent_idx[I_l], i_b] - parent_quat = links_state.quat[links_info.parent_idx[I_l], i_b] - pos = parent_pos + gu.ti_transform_by_quat(pos, parent_quat) - quat = gu.ti_transform_quat_by_quat(quat, parent_quat) + links_state.pos_bw[i_l, 0, i_b] = links_info.pos[I_l] + links_state.quat_bw[i_l, 0, i_b] = links_info.quat[I_l] + if links_info.parent_idx[I_l] != -1: + parent_pos = links_state.pos[links_info.parent_idx[I_l], i_b] + parent_quat = links_state.quat[links_info.parent_idx[I_l], i_b] + links_state.pos_bw[i_l, 0, i_b] = parent_pos + gu.ti_transform_by_quat(links_info.pos[I_l], parent_quat) + links_state.quat_bw[i_l, 0, i_b] = gu.ti_transform_quat_by_quat(links_info.quat[I_l], parent_quat) - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - q_start = joints_info.q_start[I_j] - dof_start = joints_info.dof_start[I_j] - I_d = [dof_start, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else dof_start + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] - # compute axis and anchor - if joint_type == gs.JOINT_TYPE.FREE: - joints_state.xanchor[i_j, i_b] = ti.Vector( - [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - ] - ) - joints_state.xaxis[i_j, i_b] = ti.Vector([0.0, 0.0, 1.0]) - elif joint_type == gs.JOINT_TYPE.FIXED: - pass - else: - axis = ti.Vector([0.0, 0.0, 1.0], dt=gs.ti_float) - if joint_type == gs.JOINT_TYPE.REVOLUTE: - axis = dofs_info.motion_ang[I_d] - elif joint_type == gs.JOINT_TYPE.PRISMATIC: - axis = dofs_info.motion_vel[I_d] + for i_j_ in ( + range(n_joints) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] - joints_state.xanchor[i_j, i_b] = gu.ti_transform_by_quat(joints_info.pos[I_j], quat) + pos - joints_state.xaxis[i_j, i_b] = gu.ti_transform_by_quat(axis, quat) + curr_i_j = 0 if ti.static(not is_backward) else i_j_ + next_i_j = 0 if ti.static(not is_backward) else i_j_ + 1 - if joint_type == gs.JOINT_TYPE.FREE: - pos = ti.Vector( - [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - ], - dt=gs.ti_float, - ) - quat = ti.Vector( - [ - rigid_global_info.qpos[q_start + 3, i_b], - rigid_global_info.qpos[q_start + 4, i_b], - rigid_global_info.qpos[q_start + 5, i_b], - rigid_global_info.qpos[q_start + 6, i_b], - ], - dt=gs.ti_float, - ) - xyz = gu.ti_quat_to_xyz(quat, EPS) - for j in ti.static(range(3)): - dofs_state.pos[dof_start + j, i_b] = pos[j] - dofs_state.pos[dof_start + 3 + j, i_b] = xyz[j] - elif joint_type == gs.JOINT_TYPE.FIXED: - pass - elif joint_type == gs.JOINT_TYPE.SPHERICAL: - qloc = ti.Vector( - [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - rigid_global_info.qpos[q_start + 3, i_b], - ], - dt=gs.ti_float, - ) - xyz = gu.ti_quat_to_xyz(qloc, EPS) - for j in ti.static(range(3)): - dofs_state.pos[dof_start + j, i_b] = xyz[j] - quat = gu.ti_transform_quat_by_quat(qloc, quat) - pos = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) - elif joint_type == gs.JOINT_TYPE.REVOLUTE: - axis = dofs_info.motion_ang[I_d] - dofs_state.pos[dof_start, i_b] = ( - rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] - ) - qloc = gu.ti_rotvec_to_quat(axis * dofs_state.pos[dof_start, i_b], EPS) - quat = gu.ti_transform_quat_by_quat(qloc, quat) - pos = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) - else: # joint_type == gs.JOINT_TYPE.PRISMATIC: - dofs_state.pos[dof_start, i_b] = ( - rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] - ) - pos = pos + joints_state.xaxis[i_j, i_b] * dofs_state.pos[dof_start, i_b] + if i_j < links_info.joint_end[I_l]: + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + q_start = joints_info.q_start[I_j] + dof_start = joints_info.dof_start[I_j] + I_d = [dof_start, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else dof_start + + # compute axis and anchor + if joint_type == gs.JOINT_TYPE.FREE: + joints_state.xanchor[i_j, i_b] = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + ] + ) + joints_state.xaxis[i_j, i_b] = ti.Vector([0.0, 0.0, 1.0]) + elif joint_type == gs.JOINT_TYPE.FIXED: + pass + else: + axis = ti.Vector([0.0, 0.0, 1.0], dt=gs.ti_float) + if joint_type == gs.JOINT_TYPE.REVOLUTE: + axis = dofs_info.motion_ang[I_d] + elif joint_type == gs.JOINT_TYPE.PRISMATIC: + axis = dofs_info.motion_vel[I_d] + + joints_state.xanchor[i_j, i_b] = ( + gu.ti_transform_by_quat(joints_info.pos[I_j], links_state.quat_bw[i_l, curr_i_j, i_b]) + + links_state.pos_bw[i_l, curr_i_j, i_b] + ) + joints_state.xaxis[i_j, i_b] = gu.ti_transform_by_quat( + axis, links_state.quat_bw[i_l, curr_i_j, i_b] + ) + + if joint_type == gs.JOINT_TYPE.FREE: + links_state.pos_bw[i_l, next_i_j, i_b] = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + ], + dt=gs.ti_float, + ) + links_state.quat_bw[i_l, next_i_j, i_b] = ti.Vector( + [ + rigid_global_info.qpos[q_start + 3, i_b], + rigid_global_info.qpos[q_start + 4, i_b], + rigid_global_info.qpos[q_start + 5, i_b], + rigid_global_info.qpos[q_start + 6, i_b], + ], + dt=gs.ti_float, + ) + xyz = gu.ti_quat_to_xyz(links_state.quat_bw[i_l, next_i_j, i_b], EPS) + for j in ti.static(range(3)): + dofs_state.pos[dof_start + j, i_b] = links_state.pos_bw[i_l, next_i_j, i_b][j] + dofs_state.pos[dof_start + 3 + j, i_b] = xyz[j] + elif joint_type == gs.JOINT_TYPE.FIXED: + pass + elif joint_type == gs.JOINT_TYPE.SPHERICAL: + qloc = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + rigid_global_info.qpos[q_start + 3, i_b], + ], + dt=gs.ti_float, + ) + xyz = gu.ti_quat_to_xyz(qloc, EPS) + for j in ti.static(range(3)): + dofs_state.pos[dof_start + j, i_b] = xyz[j] + links_state.quat_bw[i_l, next_i_j, i_b] = gu.ti_transform_quat_by_quat( + qloc, links_state.quat_bw[i_l, curr_i_j, i_b] + ) + links_state.pos_bw[i_l, next_i_j, i_b] = joints_state.xanchor[ + i_j, i_b + ] - gu.ti_transform_by_quat(joints_info.pos[I_j], links_state.quat_bw[i_l, next_i_j, i_b]) + elif joint_type == gs.JOINT_TYPE.REVOLUTE: + axis = dofs_info.motion_ang[I_d] + dofs_state.pos[dof_start, i_b] = ( + rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + ) + qloc = gu.ti_rotvec_to_quat(axis * dofs_state.pos[dof_start, i_b], EPS) + links_state.quat_bw[i_l, next_i_j, i_b] = gu.ti_transform_quat_by_quat( + qloc, links_state.quat_bw[i_l, curr_i_j, i_b] + ) + links_state.pos_bw[i_l, next_i_j, i_b] = joints_state.xanchor[ + i_j, i_b + ] - gu.ti_transform_by_quat(joints_info.pos[I_j], links_state.quat_bw[i_l, next_i_j, i_b]) + else: # joint_type == gs.JOINT_TYPE.PRISMATIC: + dofs_state.pos[dof_start, i_b] = ( + rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + ) + links_state.pos_bw[i_l, next_i_j, i_b] = ( + links_state.pos_bw[i_l, curr_i_j, i_b] + + joints_state.xaxis[i_j, i_b] * dofs_state.pos[dof_start, i_b] + ) - # Skip link pose update for fixed root links to let users manually overwrite them - if not (links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]): - links_state.pos[i_l, i_b] = pos - links_state.quat[i_l, i_b] = quat + # Skip link pose update for fixed root links to let users manually overwrite them + i_j_ = 0 if ti.static(not is_backward) else n_joints + if not (links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]): + links_state.pos[i_l, i_b] = links_state.pos_bw[i_l, i_j_, i_b] + links_state.quat[i_l, i_b] = links_state.quat_bw[i_l, i_j_, i_b] @ti.func @@ -4831,71 +5849,113 @@ def func_forward_velocity_entity( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + for i_l_ in ( + range(entities_info.link_start[i_e], entities_info.link_end[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) - cvel_vel = ti.Vector.zero(gs.ti_float, 3) - cvel_ang = ti.Vector.zero(gs.ti_float, 3) - if links_info.parent_idx[I_l] != -1: - cvel_vel = links_state.cd_vel[links_info.parent_idx[I_l], i_b] - cvel_ang = links_state.cd_ang[links_info.parent_idx[I_l], i_b] + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - q_start = joints_info.q_start[I_j] - dof_start = joints_info.dof_start[I_j] + links_state.cd_vel_bw[i_l, 0, i_b] = ti.Vector.zero(gs.ti_float, 3) + links_state.cd_ang_bw[i_l, 0, i_b] = ti.Vector.zero(gs.ti_float, 3) - if joint_type == gs.JOINT_TYPE.FREE: - for i_3 in ti.static(range(3)): - cvel_vel = ( - cvel_vel + dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] - ) - cvel_ang = ( - cvel_ang + dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] - ) + if links_info.parent_idx[I_l] != -1: + links_state.cd_vel_bw[i_l, 0, i_b] = links_state.cd_vel[links_info.parent_idx[I_l], i_b] + links_state.cd_ang_bw[i_l, 0, i_b] = links_state.cd_ang[links_info.parent_idx[I_l], i_b] - for i_3 in ti.static(range(3)): - ( - dofs_state.cdofd_ang[dof_start + i_3, i_b], - dofs_state.cdofd_vel[dof_start + i_3, i_b], - ) = ti.Vector.zero(gs.ti_float, 3), ti.Vector.zero(gs.ti_float, 3) + for i_j_ in ( + range(n_joints) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] - ( - dofs_state.cdofd_ang[dof_start + i_3 + 3, i_b], - dofs_state.cdofd_vel[dof_start + i_3 + 3, i_b], - ) = gu.motion_cross_motion( - cvel_ang, - cvel_vel, - dofs_state.cdof_ang[dof_start + i_3 + 3, i_b], - dofs_state.cdof_vel[dof_start + i_3 + 3, i_b], - ) + if i_j < links_info.joint_end[I_l]: + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + q_start = joints_info.q_start[I_j] + dof_start = joints_info.dof_start[I_j] - for i_3 in ti.static(range(3)): - cvel_vel = ( - cvel_vel - + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] - ) - cvel_ang = ( - cvel_ang - + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] - ) + curr_i_j = 0 if ti.static(not is_backward) else i_j_ + next_i_j = 0 if ti.static(not is_backward) else i_j_ + 1 - else: - for i_d in range(dof_start, joints_info.dof_end[I_j]): - dofs_state.cdofd_ang[i_d, i_b], dofs_state.cdofd_vel[i_d, i_b] = gu.motion_cross_motion( - cvel_ang, - cvel_vel, - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - ) - for i_d in range(dof_start, joints_info.dof_end[I_j]): - cvel_vel = cvel_vel + dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - cvel_ang = cvel_ang + dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + if joint_type == gs.JOINT_TYPE.FREE: + for i_3 in ti.static(range(3)): + links_state.cd_vel_bw[i_l, curr_i_j, i_b] += ( + dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + ) + links_state.cd_ang_bw[i_l, curr_i_j, i_b] += ( + dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + ) - links_state.cd_vel[i_l, i_b] = cvel_vel - links_state.cd_ang[i_l, i_b] = cvel_ang + for i_3 in ti.static(range(3)): + ( + dofs_state.cdofd_ang[dof_start + i_3, i_b], + dofs_state.cdofd_vel[dof_start + i_3, i_b], + ) = ti.Vector.zero(gs.ti_float, 3), ti.Vector.zero(gs.ti_float, 3) + + ( + dofs_state.cdofd_ang[dof_start + i_3 + 3, i_b], + dofs_state.cdofd_vel[dof_start + i_3 + 3, i_b], + ) = gu.motion_cross_motion( + links_state.cd_ang_bw[i_l, curr_i_j, i_b], + links_state.cd_vel_bw[i_l, curr_i_j, i_b], + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b], + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b], + ) + + links_state.cd_vel_bw[i_l, next_i_j, i_b] = links_state.cd_vel_bw[i_l, curr_i_j, i_b] + links_state.cd_ang_bw[i_l, next_i_j, i_b] = links_state.cd_ang_bw[i_l, curr_i_j, i_b] + + for i_3 in ti.static(range(3)): + links_state.cd_vel_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) + links_state.cd_ang_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) + + else: + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + dof_start) + if i_d < joints_info.dof_end[I_j]: + dofs_state.cdofd_ang[i_d, i_b], dofs_state.cdofd_vel[i_d, i_b] = gu.motion_cross_motion( + links_state.cd_ang_bw[i_l, curr_i_j, i_b], + links_state.cd_vel_bw[i_l, curr_i_j, i_b], + dofs_state.cdof_ang[i_d, i_b], + dofs_state.cdof_vel[i_d, i_b], + ) + + links_state.cd_vel_bw[i_l, next_i_j, i_b] = links_state.cd_vel_bw[i_l, curr_i_j, i_b] + links_state.cd_ang_bw[i_l, next_i_j, i_b] = links_state.cd_ang_bw[i_l, curr_i_j, i_b] + + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + dof_start) + if i_d < joints_info.dof_end[I_j]: + links_state.cd_vel_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + links_state.cd_ang_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + + i_j_ = 0 if ti.static(not is_backward) else n_joints + links_state.cd_vel[i_l, i_b] = links_state.cd_vel_bw[i_l, i_j_, i_b] + links_state.cd_ang[i_l, i_b] = links_state.cd_ang_bw[i_l, i_j_, i_b] @ti.kernel(fastcache=gs.use_fastcache) @@ -4908,6 +5968,7 @@ def kernel_update_geoms( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), force_update_fixed_geoms: ti.template(), + is_backward: ti.template(), ): for i_b_ in range(envs_idx.shape[0]): i_b = envs_idx[i_b_] @@ -4921,6 +5982,7 @@ def kernel_update_geoms( rigid_global_info, static_rigid_sim_config, force_update_fixed_geoms, + is_backward, ) @@ -4934,15 +5996,47 @@ def func_update_geoms( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), force_update_fixed_geoms: ti.template(), + is_backward: ti.template(), ): """ NOTE: this only update geom pose, not its verts and else. """ n_geoms = geoms_info.pos.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_g in range(entities_info.geom_start[i_e], entities_info.geom_end[i_e]): + for i_0 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_geoms) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(geoms_info.pos.shape[0])) + ) + ): + i_e = rigid_global_info.awake_entities[i_0, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 0 + n_geoms = entities_info.geom_end[i_e] - entities_info.geom_start[i_e] + + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(n_geoms) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_geoms_per_entity)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + i_g = i_1 + entities_info.geom_start[i_e] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 + if i_1 < (n_geoms if ti.static(static_rigid_sim_config.use_hibernation) else 1): if force_update_fixed_geoms or not geoms_info.is_fixed[i_g]: ( geoms_state.pos[i_g, i_b], @@ -4955,20 +6049,6 @@ def func_update_geoms( ) geoms_state.verts_updated[i_g, i_b] = False - else: - for i_g in range(n_geoms): - if force_update_fixed_geoms or not geoms_info.is_fixed[i_g]: - ( - geoms_state.pos[i_g, i_b], - geoms_state.quat[i_g, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - geoms_info.pos[i_g], - geoms_info.quat[i_g], - links_state.pos[geoms_info.link_idx[i_g], i_b], - links_state.quat[geoms_info.link_idx[i_g], i_b], - ) - - geoms_state.verts_updated[i_g, i_b] = False @ti.kernel(fastcache=gs.use_fastcache) @@ -5170,7 +6250,9 @@ def func_hibernate__for_all_awake_islands_either_hiberanate_or_update_aabb_sort_ ) # store entities in the hibernated islands by daisy chaining them - ci.entity_idx_to_next_entity_idx_in_hibernated_island[prev_entity_idx, i_b] = entity_idx + contact_island_state.entity_idx_to_next_entity_idx_in_hibernated_island[ + prev_entity_idx, i_b + ] = entity_idx prev_entity_idx = entity_idx @@ -5330,16 +6412,16 @@ def func_clear_external_force( _B = links_state.pos.shape[1] n_links = links_state.pos.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - links_state.cfrc_applied_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - links_state.cfrc_applied_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - else: - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_l, i_b in ti.ndrange(n_links, _B): + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_links, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 links_state.cfrc_applied_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) links_state.cfrc_applied_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) @@ -5357,170 +6439,201 @@ def func_torque_and_passive_force( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): - EPS = rigid_global_info.EPS[None] - - n_entities = entities_info.n_links.shape[0] - _B = dofs_state.ctrl_mode.shape[1] - n_dofs = dofs_state.ctrl_mode.shape[0] - n_links = links_info.root_idx.shape[0] - # compute force based on each dof's ctrl mode ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e, i_b in ti.ndrange(n_entities, _B): + for i_e, i_b in ti.ndrange(entities_info.n_links.shape[0], dofs_state.ctrl_mode.shape[1]): wakeup = False - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue + EPS = rigid_global_info.EPS[None] - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] + for i_l_ in ( + range(entities_info.link_start[i_e], entities_info.link_end[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - force = gs.ti_float(0.0) - if dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.FORCE: - force = dofs_state.ctrl_force[i_d, i_b] - elif dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY: - force = dofs_info.kv[I_d] * (dofs_state.ctrl_vel[i_d, i_b] - dofs_state.vel[i_d, i_b]) - elif dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION and not ( - joint_type == gs.JOINT_TYPE.FREE and i_d >= links_info.dof_start[I_l] + 3 - ): - force = dofs_info.kp[I_d] * ( - dofs_state.ctrl_pos[i_d, i_b] - dofs_state.pos[i_d, i_b] - ) + dofs_info.kv[I_d] * (dofs_state.ctrl_vel[i_d, i_b] - dofs_state.vel[i_d, i_b]) - - dofs_state.qf_applied[i_d, i_b] = ti.math.clamp( - force, - dofs_info.force_range[I_d][0], - dofs_info.force_range[I_d][1], - ) + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + if links_info.n_dofs[I_l] > 0: + i_j = links_info.joint_start[I_l] + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] - if ti.abs(force) > EPS: - wakeup = True + for i_d_ in ( + range(links_info.dof_start[I_l], links_info.dof_end[I_l]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + links_info.dof_start[I_l]) - dof_start = links_info.dof_start[I_l] - if joint_type == gs.JOINT_TYPE.FREE and ( - dofs_state.ctrl_mode[dof_start + 3, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[dof_start + 4, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[dof_start + 5, i_b] == gs.CTRL_MODE.POSITION - ): - xyz = ti.Vector( - [ - dofs_state.pos[0 + 3 + dof_start, i_b], - dofs_state.pos[1 + 3 + dof_start, i_b], - dofs_state.pos[2 + 3 + dof_start, i_b], - ], - dt=gs.ti_float, - ) + if i_d < links_info.dof_end[I_l]: + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + force = gs.ti_float(0.0) + if dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.FORCE: + force = dofs_state.ctrl_force[i_d, i_b] + elif dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY: + force = dofs_info.kv[I_d] * (dofs_state.ctrl_vel[i_d, i_b] - dofs_state.vel[i_d, i_b]) + elif dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION and not ( + joint_type == gs.JOINT_TYPE.FREE and i_d >= links_info.dof_start[I_l] + 3 + ): + force = dofs_info.kp[I_d] * ( + dofs_state.ctrl_pos[i_d, i_b] - dofs_state.pos[i_d, i_b] + ) + dofs_info.kv[I_d] * (dofs_state.ctrl_vel[i_d, i_b] - dofs_state.vel[i_d, i_b]) + + dofs_state.qf_applied[i_d, i_b] = ti.math.clamp( + force, + dofs_info.force_range[I_d][0], + dofs_info.force_range[I_d][1], + ) - ctrl_xyz = ti.Vector( - [ - dofs_state.ctrl_pos[0 + 3 + dof_start, i_b], - dofs_state.ctrl_pos[1 + 3 + dof_start, i_b], - dofs_state.ctrl_pos[2 + 3 + dof_start, i_b], - ], - dt=gs.ti_float, - ) + if ti.abs(force) > EPS: + wakeup = True - quat = gu.ti_xyz_to_quat(xyz) - ctrl_quat = gu.ti_xyz_to_quat(ctrl_xyz) + dof_start = links_info.dof_start[I_l] + if joint_type == gs.JOINT_TYPE.FREE and ( + dofs_state.ctrl_mode[dof_start + 3, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[dof_start + 4, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[dof_start + 5, i_b] == gs.CTRL_MODE.POSITION + ): + xyz = ti.Vector( + [ + dofs_state.pos[0 + 3 + dof_start, i_b], + dofs_state.pos[1 + 3 + dof_start, i_b], + dofs_state.pos[2 + 3 + dof_start, i_b], + ], + dt=gs.ti_float, + ) + + ctrl_xyz = ti.Vector( + [ + dofs_state.ctrl_pos[0 + 3 + dof_start, i_b], + dofs_state.ctrl_pos[1 + 3 + dof_start, i_b], + dofs_state.ctrl_pos[2 + 3 + dof_start, i_b], + ], + dt=gs.ti_float, + ) - q_diff = gu.ti_transform_quat_by_quat(ctrl_quat, gu.ti_inv_quat(quat)) - rotvec = gu.ti_quat_to_rotvec(q_diff, EPS) + quat = gu.ti_xyz_to_quat(xyz) + ctrl_quat = gu.ti_xyz_to_quat(ctrl_xyz) - for j in ti.static(range(3)): - i_d = dof_start + 3 + j - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - force = dofs_info.kp[I_d] * rotvec[j] - dofs_info.kv[I_d] * dofs_state.vel[i_d, i_b] + q_diff = gu.ti_transform_quat_by_quat(ctrl_quat, gu.ti_inv_quat(quat)) + rotvec = gu.ti_quat_to_rotvec(q_diff, EPS) - dofs_state.qf_applied[i_d, i_b] = ti.math.clamp( - force, dofs_info.force_range[I_d][0], dofs_info.force_range[I_d][1] - ) + for j in ti.static(range(3)): + i_d = dof_start + 3 + j + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + force = dofs_info.kp[I_d] * rotvec[j] - dofs_info.kv[I_d] * dofs_state.vel[i_d, i_b] + + dofs_state.qf_applied[i_d, i_b] = ti.math.clamp( + force, dofs_info.force_range[I_d][0], dofs_info.force_range[I_d][1] + ) - if ti.abs(force) > EPS: - wakeup = True + if ti.abs(force) > EPS: + wakeup = True - if ti.static(static_rigid_sim_config.use_hibernation) and entities_state.hibernated[i_e, i_b] and wakeup: - func_wakeup_entity_and_its_temp_island( - i_e, - i_b, - entities_state, - entities_info, - dofs_state, - links_state, - geoms_state, - rigid_global_info, - contact_island_state, + if ti.static(static_rigid_sim_config.use_hibernation): + if entities_state.hibernated[i_e, i_b] and wakeup: + # TODO: migrate this function + func_wakeup_entity_and_its_temp_island( + i_e, + i_b, + entities_state, + entities_info, + dofs_state, + links_state, + geoms_state, + rigid_global_info, + contact_island_state, + ) + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, dofs_state.ctrl_mode.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(dofs_state.ctrl_mode.shape[0], dofs_state.ctrl_mode.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner for forward pass + range(rigid_global_info.n_awake_dofs[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) ) + ): + if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): + i_d = ( + rigid_global_info.awake_dofs[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): - i_d = rigid_global_info.awake_dofs[i_d_, i_b] I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - dofs_state.qf_passive[i_d, i_b] = -dofs_info.damping[I_d] * dofs_state.vel[i_d, i_b] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, dofs_state.ctrl_mode.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(links_info.root_idx.shape[0], dofs_state.ctrl_mode.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - - if joint_type != gs.JOINT_TYPE.FREE and joint_type != gs.JOINT_TYPE.FIXED: - q_start = links_info.q_start[I_l] - dof_start = links_info.dof_start[I_l] - dof_end = links_info.dof_end[I_l] - - for j_d in range(dof_end - dof_start): - I_d = ( - [dof_start + j_d, i_b] - if ti.static(static_rigid_sim_config.batch_dofs_info) - else dof_start + j_d - ) - dofs_state.qf_passive[dof_start + j_d, i_b] += ( - -rigid_global_info.qpos[q_start + j_d, i_b] * dofs_info.stiffness[I_d] - ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - dofs_state.qf_passive[i_d, i_b] = -dofs_info.damping[I_d] * dofs_state.vel[i_d, i_b] - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] + if links_info.n_dofs[I_l] > 0: + i_j = links_info.joint_start[I_l] + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] - if joint_type != gs.JOINT_TYPE.FREE and joint_type != gs.JOINT_TYPE.FIXED: - q_start = links_info.q_start[I_l] - dof_start = links_info.dof_start[I_l] - dof_end = links_info.dof_end[I_l] + if joint_type != gs.JOINT_TYPE.FREE and joint_type != gs.JOINT_TYPE.FIXED: + q_start = links_info.q_start[I_l] + dof_start = links_info.dof_start[I_l] + dof_end = links_info.dof_end[I_l] - for j_d in range(dof_end - dof_start): - I_d = ( - [dof_start + j_d, i_b] - if ti.static(static_rigid_sim_config.batch_dofs_info) - else dof_start + j_d - ) - dofs_state.qf_passive[dof_start + j_d, i_b] += ( - -rigid_global_info.qpos[q_start + j_d, i_b] * dofs_info.stiffness[I_d] - ) + for j_d in ( + range(dof_end - dof_start) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + if j_d < dof_end: + I_d = ( + [dof_start + j_d, i_b] + if ti.static(static_rigid_sim_config.batch_dofs_info) + else dof_start + j_d + ) + dofs_state.qf_passive[dof_start + j_d, i_b] += ( + -rigid_global_info.qpos[q_start + j_d, i_b] * dofs_info.stiffness[I_d] + ) @ti.func @@ -5532,90 +6645,85 @@ def func_update_acc( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = dofs_state.ctrl_mode.shape[1] - n_links = links_info.root_idx.shape[0] - n_entities = entities_info.n_links.shape[0] + # Assume this is the outermost loop + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, dofs_state.ctrl_mode.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(entities_info.n_links.shape[0], dofs_state.ctrl_mode.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - - if i_p == -1: - links_state.cdd_vel[i_l, i_b] = -rigid_global_info.gravity[i_b] * ( - 1 - entities_info.gravity_compensation[i_e] - ) - links_state.cdd_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - links_state.cacc_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - else: - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_p, i_b] - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_p, i_b] - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = links_state.cacc_lin[i_p, i_b] - links_state.cacc_ang[i_l, i_b] = links_state.cacc_ang[i_p, i_b] - - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_l, i_b] + local_cdd_vel - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_l, i_b] + local_cdd_ang - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ( - links_state.cacc_lin[i_l, i_b] - + local_cdd_vel - + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) - links_state.cacc_ang[i_l, i_b] = ( - links_state.cacc_ang[i_l, i_b] - + local_cdd_ang - + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e, i_b in ti.ndrange(n_entities, _B): - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] + for i_l_ in ( + range(entities_info.link_start[i_e], entities_info.link_end[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) - if i_p == -1: - links_state.cdd_vel[i_l, i_b] = -rigid_global_info.gravity[i_b] * ( - 1 - entities_info.gravity_compensation[i_e] - ) - links_state.cdd_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - links_state.cacc_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - else: - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_p, i_b] - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_p, i_b] - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = links_state.cacc_lin[i_p, i_b] - links_state.cacc_ang[i_l, i_b] = links_state.cacc_ang[i_p, i_b] - - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - # cacc = cacc_parent + cdofdot * qvel + cdof * qacc - local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_l, i_b] + local_cdd_vel - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_l, i_b] + local_cdd_ang - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ( - links_state.cacc_lin[i_l, i_b] - + local_cdd_vel - + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) - links_state.cacc_ang[i_l, i_b] = ( - links_state.cacc_ang[i_l, i_b] - + local_cdd_ang - + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + i_p = links_info.parent_idx[I_l] + + if i_p == -1: + links_state.cdd_vel[i_l, i_b] = -rigid_global_info.gravity[i_b] * ( + 1 - entities_info.gravity_compensation[i_e] + ) + links_state.cdd_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + if ti.static(update_cacc): + links_state.cacc_lin[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + links_state.cacc_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + else: + links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_p, i_b] + links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_p, i_b] + if ti.static(update_cacc): + links_state.cacc_lin[i_l, i_b] = links_state.cacc_lin[i_p, i_b] + links_state.cacc_ang[i_l, i_b] = links_state.cacc_ang[i_p, i_b] + + for i_d_ in ( + range(links_info.dof_start[I_l], links_info.dof_end[I_l]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + links_info.dof_start[I_l]) + + if i_d < links_info.dof_end[I_l]: + # cacc = cacc_parent + cdofdot * qvel + cdof * qacc + local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + links_state.cdd_vel[i_l, i_b] += local_cdd_vel + links_state.cdd_ang[i_l, i_b] += local_cdd_ang + if ti.static(update_cacc): + links_state.cacc_lin[i_l, i_b] += ( + local_cdd_vel + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b] + ) + links_state.cacc_ang[i_l, i_b] += ( + local_cdd_ang + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b] + ) @ti.func @@ -5625,16 +6733,37 @@ def func_update_force( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = links_state.pos.shape[1] - n_links = links_info.root_idx.shape[0] - n_entities = entities_info.n_links.shape[0] - - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(links_info.root_idx.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) f1_ang, f1_vel = gu.inertial_mul( links_state.cinr_pos[i_l, i_b], @@ -5645,73 +6774,68 @@ def func_update_force( ) f2_ang, f2_vel = gu.inertial_mul( links_state.cinr_pos[i_l, i_b], - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - links_state.cd_vel[i_l, i_b], - links_state.cd_ang[i_l, i_b], - ) - f2_ang, f2_vel = gu.motion_cross_force( - links_state.cd_ang[i_l, i_b], links_state.cd_vel[i_l, i_b], f2_ang, f2_vel - ) - - links_state.cfrc_vel[i_l, i_b] = ( - f1_vel + f2_vel + links_state.cfrc_applied_vel[i_l, i_b] + links_state.cfrc_coupling_vel[i_l, i_b] - ) - links_state.cfrc_ang[i_l, i_b] = ( - f1_ang + f2_ang + links_state.cfrc_applied_ang[i_l, i_b] + links_state.cfrc_coupling_ang[i_l, i_b] - ) - - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_l_ in range(entities_info.n_links[i_e]): - i_l = entities_info.link_end[i_e] - 1 - i_l_ - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - if i_p != -1: - links_state.cfrc_vel[i_p, i_b] = links_state.cfrc_vel[i_p, i_b] + links_state.cfrc_vel[i_l, i_b] - links_state.cfrc_ang[i_p, i_b] = links_state.cfrc_ang[i_p, i_b] + links_state.cfrc_ang[i_l, i_b] - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - f1_ang, f1_vel = gu.inertial_mul( - links_state.cinr_pos[i_l, i_b], - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - links_state.cdd_vel[i_l, i_b], - links_state.cdd_ang[i_l, i_b], - ) - f2_ang, f2_vel = gu.inertial_mul( - links_state.cinr_pos[i_l, i_b], - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - links_state.cd_vel[i_l, i_b], - links_state.cd_ang[i_l, i_b], - ) - f2_ang, f2_vel = gu.motion_cross_force( - links_state.cd_ang[i_l, i_b], links_state.cd_vel[i_l, i_b], f2_ang, f2_vel - ) + links_state.cinr_inertial[i_l, i_b], + links_state.cinr_mass[i_l, i_b], + links_state.cd_vel[i_l, i_b], + links_state.cd_ang[i_l, i_b], + ) + f3_ang, f3_vel = gu.motion_cross_force( + links_state.cd_ang[i_l, i_b], links_state.cd_vel[i_l, i_b], f2_ang, f2_vel + ) + + links_state.cfrc_vel[i_l, i_b] = ( + f1_vel + f3_vel + links_state.cfrc_applied_vel[i_l, i_b] + links_state.cfrc_coupling_vel[i_l, i_b] + ) + links_state.cfrc_ang[i_l, i_b] = ( + f1_ang + f3_ang + links_state.cfrc_applied_ang[i_l, i_b] + links_state.cfrc_coupling_ang[i_l, i_b] + ) - links_state.cfrc_vel[i_l, i_b] = ( - f1_vel + f2_vel + links_state.cfrc_applied_vel[i_l, i_b] + links_state.cfrc_coupling_vel[i_l, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) ) - links_state.cfrc_ang[i_l, i_b] = ( - f1_ang + f2_ang + links_state.cfrc_applied_ang[i_l, i_b] + links_state.cfrc_coupling_ang[i_l, i_b] + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e, i_b in ti.ndrange(n_entities, _B): - for i_l_ in range(entities_info.n_links[i_e]): - i_l = entities_info.link_end[i_e] - 1 - i_l_ - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - if i_p != -1: - links_state.cfrc_vel[i_p, i_b] = links_state.cfrc_vel[i_p, i_b] + links_state.cfrc_vel[i_l, i_b] - links_state.cfrc_ang[i_p, i_b] = links_state.cfrc_ang[i_p, i_b] + links_state.cfrc_ang[i_l, i_b] + for i_l_ in ( + range(entities_info.n_links[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + if i_l_ < entities_info.n_links[i_e]: + i_l = entities_info.link_end[i_e] - 1 - i_l_ + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + i_p = links_info.parent_idx[I_l] + if i_p != -1: + links_state.cfrc_vel[i_p, i_b] += links_state.cfrc_vel[i_l, i_b] + links_state.cfrc_ang[i_p, i_b] += links_state.cfrc_ang[i_l, i_b] # Clear coupling forces after use ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): + for i_l, i_b in ti.ndrange(links_info.root_idx.shape[0], links_state.pos.shape[1]): links_state.cfrc_coupling_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) links_state.cfrc_coupling_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) @@ -5747,49 +6871,75 @@ def func_bias_force( links_info: array_class.LinksInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = dofs_state.ctrl_mode.shape[1] - n_links = links_info.root_idx.shape[0] - - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, dofs_state.ctrl_mode.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(links_info.root_idx.shape[0], dofs_state.ctrl_mode.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.qf_bias[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b].dot( - links_state.cfrc_ang[i_l, i_b] - ) + dofs_state.cdof_vel[i_d, i_b].dot(links_state.cfrc_vel[i_l, i_b]) - - dofs_state.force[i_d, i_b] = ( - dofs_state.qf_passive[i_d, i_b] - - dofs_state.qf_bias[i_d, i_b] - + dofs_state.qf_applied[i_d, i_b] - # + self.dofs_state.qf_actuator[i_d, i_b] - ) - - dofs_state.qf_smooth[i_d, i_b] = dofs_state.force[i_d, i_b] - - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + for i_d_ in ( + range(links_info.dof_start[I_l], links_info.dof_end[I_l]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + links_info.dof_start[I_l]) + if i_d < links_info.dof_end[I_l]: + dofs_state.qf_bias[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b].dot( + links_state.cfrc_ang[i_l, i_b] + ) + dofs_state.cdof_vel[i_d, i_b].dot(links_state.cfrc_vel[i_l, i_b]) + + dofs_state.force[i_d, i_b] = ( + dofs_state.qf_passive[i_d, i_b] + - dofs_state.qf_bias[i_d, i_b] + + dofs_state.qf_applied[i_d, i_b] + # + self.dofs_state.qf_actuator[i_d, i_b] + ) - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.qf_bias[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b].dot( - links_state.cfrc_ang[i_l, i_b] - ) + dofs_state.cdof_vel[i_d, i_b].dot(links_state.cfrc_vel[i_l, i_b]) + dofs_state.qf_smooth[i_d, i_b] = dofs_state.force[i_d, i_b] - dofs_state.force[i_d, i_b] = ( - dofs_state.qf_passive[i_d, i_b] - - dofs_state.qf_bias[i_d, i_b] - + dofs_state.qf_applied[i_d, i_b] - # + self.dofs_state.qf_actuator[i_d, i_b] - ) - dofs_state.qf_smooth[i_d, i_b] = dofs_state.force[i_d, i_b] +@ti.kernel +def kernel_compute_qacc( + dofs_state: array_class.DofsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + is_backward: ti.template(), +): + func_compute_qacc( + dofs_state=dofs_state, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) @ti.func @@ -5798,32 +6948,57 @@ def func_compute_qacc( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = dofs_state.ctrl_mode.shape[1] - n_entities = entities_info.n_links.shape[0] - func_solve_mass( vec=dofs_state.force, out=dofs_state.acc_smooth, + out_bw=dofs_state.acc_smooth_bw, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_d1_ in range(entities_info.n_dofs[i_e]): + # Assume this is the outermost loop + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) + for i_0, i_b in ( + ti.ndrange(1, dofs_state.ctrl_mode.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(entities_info.n_links.shape[0], dofs_state.ctrl_mode.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) + + for i_d1_ in ( + range(entities_info.n_dofs[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): i_d1 = entities_info.dof_start[i_e] + i_d1_ - dofs_state.acc[i_d1, i_b] = dofs_state.acc_smooth[i_d1, i_b] - else: - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_e, i_b in ti.ndrange(n_entities, _B): - for i_d1_ in range(entities_info.n_dofs[i_e]): - i_d1 = entities_info.dof_start[i_e] + i_d1_ - dofs_state.acc[i_d1, i_b] = dofs_state.acc_smooth[i_d1, i_b] + if i_d1 < entities_info.dof_end[i_e]: + dofs_state.acc[i_d1, i_b] = dofs_state.acc_smooth[i_d1, i_b] @ti.func @@ -5833,56 +7008,81 @@ def func_integrate( joints_info: array_class.JointsInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - EPS = rigid_global_info.EPS[None] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + (ti.ndrange(1, dofs_state.ctrl_mode.shape[1])) + if ti.static(static_rigid_sim_config.use_hibernation) + else (ti.ndrange(dofs_state.ctrl_mode.shape[0], dofs_state.ctrl_mode.shape[1])) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_dofs[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): + i_d = ( + rigid_global_info.awake_dofs[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - _B = dofs_state.ctrl_mode.shape[1] - n_dofs = dofs_state.ctrl_mode.shape[0] - n_links = links_info.root_idx.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): - i_d = rigid_global_info.awake_dofs[i_d_, i_b] - dofs_state.vel[i_d, i_b] = ( + dofs_state.vel_next[i_d, i_b] = ( dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None] ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + (ti.ndrange(1, dofs_state.ctrl_mode.shape[1])) + if ti.static(static_rigid_sim_config.use_hibernation) + else (ti.ndrange(links_info.root_idx.shape[0], dofs_state.ctrl_mode.shape[1])) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + if links_info.n_dofs[I_l] > 0: + EPS = rigid_global_info.EPS[None] + dof_start = links_info.dof_start[I_l] + q_start = links_info.q_start[I_l] + q_end = links_info.q_end[I_l] - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): + i_j = links_info.joint_start[I_l] I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - dof_start = joints_info.dof_start[I_j] - q_start = joints_info.q_start[I_j] - q_end = joints_info.q_end[I_j] - joint_type = joints_info.type[I_j] if joint_type == gs.JOINT_TYPE.FREE: - rot = ti.Vector( - [ - rigid_global_info.qpos[q_start + 3, i_b], - rigid_global_info.qpos[q_start + 4, i_b], - rigid_global_info.qpos[q_start + 5, i_b], - rigid_global_info.qpos[q_start + 6, i_b], - ] - ) - ang = ( - ti.Vector( - [ - dofs_state.vel[dof_start + 3, i_b], - dofs_state.vel[dof_start + 4, i_b], - dofs_state.vel[dof_start + 5, i_b], - ] - ) - * rigid_global_info.substep_dt[None] - ) - qrot = gu.ti_rotvec_to_quat(ang, EPS) - rot = gu.ti_transform_quat_by_quat(qrot, rot) pos = ti.Vector( [ rigid_global_info.qpos[q_start, i_b], @@ -5892,118 +7092,139 @@ def func_integrate( ) vel = ti.Vector( [ - dofs_state.vel[dof_start, i_b], - dofs_state.vel[dof_start + 1, i_b], - dofs_state.vel[dof_start + 2, i_b], + dofs_state.vel_next[dof_start, i_b], + dofs_state.vel_next[dof_start + 1, i_b], + dofs_state.vel_next[dof_start + 2, i_b], ] ) - pos = pos + vel * rigid_global_info.substep_dt[None] + pos += vel * rigid_global_info.substep_dt[None] for j in ti.static(range(3)): - rigid_global_info.qpos[q_start + j, i_b] = pos[j] - for j in ti.static(range(4)): - rigid_global_info.qpos[q_start + j + 3, i_b] = rot[j] - elif joint_type == gs.JOINT_TYPE.FIXED: - pass - elif joint_type == gs.JOINT_TYPE.SPHERICAL: - rot = ti.Vector( + rigid_global_info.qpos_next[q_start + j, i_b] = pos[j] + if joint_type == gs.JOINT_TYPE.SPHERICAL or joint_type == gs.JOINT_TYPE.FREE: + rot_offset = 3 if joint_type == gs.JOINT_TYPE.FREE else 0 + rot0 = ti.Vector( [ - rigid_global_info.qpos[q_start + 0, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - rigid_global_info.qpos[q_start + 3, i_b], + rigid_global_info.qpos[q_start + rot_offset + 0, i_b], + rigid_global_info.qpos[q_start + rot_offset + 1, i_b], + rigid_global_info.qpos[q_start + rot_offset + 2, i_b], + rigid_global_info.qpos[q_start + rot_offset + 3, i_b], ] ) ang = ( ti.Vector( [ - dofs_state.vel[dof_start + 3, i_b], - dofs_state.vel[dof_start + 4, i_b], - dofs_state.vel[dof_start + 5, i_b], + dofs_state.vel_next[dof_start + rot_offset + 0, i_b], + dofs_state.vel_next[dof_start + rot_offset + 1, i_b], + dofs_state.vel_next[dof_start + rot_offset + 2, i_b], ] ) * rigid_global_info.substep_dt[None] ) qrot = gu.ti_rotvec_to_quat(ang, EPS) - rot = gu.ti_transform_quat_by_quat(qrot, rot) + rot = gu.ti_transform_quat_by_quat(qrot, rot0) for j in ti.static(range(4)): - rigid_global_info.qpos[q_start + j, i_b] = rot[j] - + rigid_global_info.qpos_next[q_start + j + rot_offset, i_b] = rot[j] else: - for j in range(q_end - q_start): - rigid_global_info.qpos[q_start + j, i_b] = ( - rigid_global_info.qpos[q_start + j, i_b] - + dofs_state.vel[dof_start + j, i_b] * rigid_global_info.substep_dt[None] - ) + for j_ in ( + (range(q_end - q_start)) + if ti.static(not is_backward) + else (ti.static(range(static_rigid_sim_config.max_n_qs_per_link))) + ): + j = q_start + j_ + if j < q_end: + rigid_global_info.qpos_next[j, i_b] = ( + rigid_global_info.qpos[j, i_b] + + dofs_state.vel_next[dof_start + j_, i_b] * rigid_global_info.substep_dt[None] + ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - dofs_state.vel[i_d, i_b] = ( - dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None] - ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_copy_next_to_curr( + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, +): + for i_d, i_b in ti.ndrange(dofs_state.vel.shape[0], dofs_state.vel.shape[1]): + dofs_state.vel[i_d, i_b] = dofs_state.vel_next[i_d, i_b] - dof_start = links_info.dof_start[I_l] - q_start = links_info.q_start[I_l] - q_end = links_info.q_end[I_l] + for i_q, i_b in ti.ndrange(rigid_global_info.qpos.shape[0], rigid_global_info.qpos.shape[1]): + rigid_global_info.qpos[i_q, i_b] = rigid_global_info.qpos_next[i_q, i_b] - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - if joint_type == gs.JOINT_TYPE.FREE: - pos = ti.Vector( - [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - ] - ) - vel = ti.Vector( - [ - dofs_state.vel[dof_start, i_b], - dofs_state.vel[dof_start + 1, i_b], - dofs_state.vel[dof_start + 2, i_b], - ] - ) - pos = pos + vel * rigid_global_info.substep_dt[None] - for j in ti.static(range(3)): - rigid_global_info.qpos[q_start + j, i_b] = pos[j] - if joint_type == gs.JOINT_TYPE.SPHERICAL or joint_type == gs.JOINT_TYPE.FREE: - rot_offset = 3 if joint_type == gs.JOINT_TYPE.FREE else 0 - rot = ti.Vector( - [ - rigid_global_info.qpos[q_start + rot_offset + 0, i_b], - rigid_global_info.qpos[q_start + rot_offset + 1, i_b], - rigid_global_info.qpos[q_start + rot_offset + 2, i_b], - rigid_global_info.qpos[q_start + rot_offset + 3, i_b], - ] - ) - ang = ( - ti.Vector( - [ - dofs_state.vel[dof_start + rot_offset + 0, i_b], - dofs_state.vel[dof_start + rot_offset + 1, i_b], - dofs_state.vel[dof_start + rot_offset + 2, i_b], - ] - ) - * rigid_global_info.substep_dt[None] - ) - qrot = gu.ti_rotvec_to_quat(ang, EPS) - rot = gu.ti_transform_quat_by_quat(qrot, rot) - for j in ti.static(range(4)): - rigid_global_info.qpos[q_start + j + rot_offset, i_b] = rot[j] - else: - for j in range(q_end - q_start): - rigid_global_info.qpos[q_start + j, i_b] = ( - rigid_global_info.qpos[q_start + j, i_b] - + dofs_state.vel[dof_start + j, i_b] * rigid_global_info.substep_dt[None] - ) +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_save_adjoint_cache( + f: ti.int32, + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + rigid_adjoint_cache: array_class.RigidAdjointCache, +): + n_dofs = dofs_state.vel.shape[0] + n_qs = rigid_global_info.qpos.shape[0] + _B = dofs_state.vel.shape[1] + + for i_d, i_b in ti.ndrange(n_dofs, _B): + rigid_adjoint_cache.dofs_vel[f, i_d, i_b] = dofs_state.vel[i_d, i_b] + rigid_adjoint_cache.dofs_acc[f, i_d, i_b] = dofs_state.acc[i_d, i_b] + + for i_q, i_b in ti.ndrange(n_qs, _B): + rigid_adjoint_cache.qpos[f, i_q, i_b] = rigid_global_info.qpos[i_q, i_b] + + +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_copy_cartesian_space( + src_dofs_state: array_class.DofsState, + src_links_state: array_class.LinksState, + src_joints_state: array_class.JointsState, + src_geoms_state: array_class.GeomsState, + dst_dofs_state: array_class.DofsState, + dst_links_state: array_class.LinksState, + dst_joints_state: array_class.JointsState, + dst_geoms_state: array_class.GeomsState, +): + # Copy outputs of [kernel_update_cartesian_space] among [dofs, links, joints, geoms] states. This is used to restore + # the outputs that were overwritten if we disabled mujoco compatibility for backward pass. + + # dofs state + for i_d, i_b in ti.ndrange(src_dofs_state.pos.shape[0], src_dofs_state.pos.shape[1]): + # pos, cdof_ang, cdof_vel, cdofvel_ang, cdofvel_vel, cdofd_ang, cdofd_vel + dst_dofs_state.pos[i_d, i_b] = src_dofs_state.pos[i_d, i_b] + dst_dofs_state.cdof_ang[i_d, i_b] = src_dofs_state.cdof_ang[i_d, i_b] + dst_dofs_state.cdof_vel[i_d, i_b] = src_dofs_state.cdof_vel[i_d, i_b] + dst_dofs_state.cdofvel_ang[i_d, i_b] = src_dofs_state.cdofvel_ang[i_d, i_b] + dst_dofs_state.cdofvel_vel[i_d, i_b] = src_dofs_state.cdofvel_vel[i_d, i_b] + dst_dofs_state.cdofd_ang[i_d, i_b] = src_dofs_state.cdofd_ang[i_d, i_b] + dst_dofs_state.cdofd_vel[i_d, i_b] = src_dofs_state.cdofd_vel[i_d, i_b] + + # links state + for i_l, i_b in ti.ndrange(src_links_state.pos.shape[0], src_links_state.pos.shape[1]): + # pos, quat, root_COM, mass_sum, i_pos, i_quat, cinr_inertial, cinr_pos, cinr_quat, cinr_mass, j_pos, j_quat, + # cd_vel, cd_ang + dst_links_state.pos[i_l, i_b] = src_links_state.pos[i_l, i_b] + dst_links_state.quat[i_l, i_b] = src_links_state.quat[i_l, i_b] + dst_links_state.root_COM[i_l, i_b] = src_links_state.root_COM[i_l, i_b] + dst_links_state.mass_sum[i_l, i_b] = src_links_state.mass_sum[i_l, i_b] + dst_links_state.i_pos[i_l, i_b] = src_links_state.i_pos[i_l, i_b] + dst_links_state.i_quat[i_l, i_b] = src_links_state.i_quat[i_l, i_b] + dst_links_state.cinr_inertial[i_l, i_b] = src_links_state.cinr_inertial[i_l, i_b] + dst_links_state.cinr_pos[i_l, i_b] = src_links_state.cinr_pos[i_l, i_b] + dst_links_state.cinr_quat[i_l, i_b] = src_links_state.cinr_quat[i_l, i_b] + dst_links_state.cinr_mass[i_l, i_b] = src_links_state.cinr_mass[i_l, i_b] + dst_links_state.j_pos[i_l, i_b] = src_links_state.j_pos[i_l, i_b] + dst_links_state.j_quat[i_l, i_b] = src_links_state.j_quat[i_l, i_b] + dst_links_state.cd_vel[i_l, i_b] = src_links_state.cd_vel[i_l, i_b] + dst_links_state.cd_ang[i_l, i_b] = src_links_state.cd_ang[i_l, i_b] + + # joints state + for i_j, i_b in ti.ndrange(src_joints_state.xanchor.shape[0], src_joints_state.xanchor.shape[1]): + # xanchor, xaxis + dst_joints_state.xanchor[i_j, i_b] = src_joints_state.xanchor[i_j, i_b] + dst_joints_state.xaxis[i_j, i_b] = src_joints_state.xaxis[i_j, i_b] + + # geoms state + for i_g, i_b in ti.ndrange(src_geoms_state.pos.shape[0], src_geoms_state.pos.shape[1]): + # pos, quat, verts_updated + dst_geoms_state.pos[i_g, i_b] = src_geoms_state.pos[i_g, i_b] + dst_geoms_state.quat[i_g, i_b] = src_geoms_state.quat[i_g, i_b] + dst_geoms_state.verts_updated[i_g, i_b] = src_geoms_state.verts_updated[i_g, i_b] @ti.func @@ -6139,6 +7360,7 @@ def kernel_update_vgeoms_render_T( def kernel_get_state( qpos: ti.types.ndarray(), vel: ti.types.ndarray(), + acc: ti.types.ndarray(), links_pos: ti.types.ndarray(), links_quat: ti.types.ndarray(), i_pos_shift: ti.types.ndarray(), @@ -6164,6 +7386,7 @@ def kernel_get_state( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_d, i_b in ti.ndrange(n_dofs, _B): vel[i_b, i_d] = dofs_state.vel[i_d, i_b] + acc[i_b, i_d] = dofs_state.acc[i_d, i_b] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_l, i_b in ti.ndrange(n_links, _B): @@ -6184,6 +7407,7 @@ def kernel_get_state( def kernel_set_state( qpos: ti.types.ndarray(), dofs_vel: ti.types.ndarray(), + dofs_acc: ti.types.ndarray(), links_pos: ti.types.ndarray(), links_quat: ti.types.ndarray(), i_pos_shift: ti.types.ndarray(), @@ -6209,6 +7433,7 @@ def kernel_set_state( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_d, i_b_ in ti.ndrange(n_dofs, envs_idx.shape[0]): dofs_state.vel[i_d, envs_idx[i_b_]] = dofs_vel[envs_idx[i_b_], i_d] + dofs_state.acc[i_d, envs_idx[i_b_]] = dofs_acc[envs_idx[i_b_], i_d] dofs_state.ctrl_force[i_d, envs_idx[i_b_]] = gs.ti_float(0.0) dofs_state.ctrl_mode[i_d, envs_idx[i_b_]] = gs.CTRL_MODE.FORCE @@ -6226,6 +7451,39 @@ def kernel_set_state( geoms_state.friction_ratio[i_l, envs_idx[i_b_]] = friction_ratio[envs_idx[i_b_], i_l] +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_get_state_grad( + qpos_grad: ti.types.ndarray(), + vel_grad: ti.types.ndarray(), + links_pos_grad: ti.types.ndarray(), + links_quat_grad: ti.types.ndarray(), + links_state: array_class.LinksState, + dofs_state: array_class.DofsState, + geoms_state: array_class.GeomsState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + n_qs = qpos_grad.shape[1] + n_dofs = vel_grad.shape[1] + n_links = links_pos_grad.shape[1] + _B = qpos_grad.shape[0] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_q, i_b in ti.ndrange(n_qs, _B): + rigid_global_info.qpos.grad[i_q, i_b] += qpos_grad[i_b, i_q] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(n_dofs, _B): + dofs_state.vel.grad[i_d, i_b] += vel_grad[i_b, i_d] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(n_links, _B): + for j in ti.static(range(3)): + links_state.pos.grad[i_l, i_b][j] += links_pos_grad[i_b, i_l, j] + for j in ti.static(range(4)): + links_state.quat.grad[i_l, i_b][j] += links_quat_grad[i_b, i_l, j] + + @ti.kernel(fastcache=gs.use_fastcache) def kernel_set_links_pos( relative: ti.i32, @@ -6261,6 +7519,35 @@ def kernel_set_links_pos( ) +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_set_links_pos_grad( + relative: ti.i32, + pos_grad: ti.types.ndarray(), + links_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + links_info: array_class.LinksInfo, + links_state: array_class.LinksState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): + i_b = envs_idx[i_b_] + i_l = links_idx[i_l_] + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + if links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]: + for j in ti.static(range(3)): + pos_grad[i_b_, i_l_, j] = links_state.pos.grad[i_l, i_b][j] + links_state.pos.grad[i_l, i_b][j] = 0.0 + else: + q_start = links_info.q_start[I_l] + for j in ti.static(range(3)): + pos_grad[i_b_, i_l_, j] = rigid_global_info.qpos.grad[q_start + j, i_b] + rigid_global_info.qpos.grad[q_start + j, i_b] = 0.0 + + @ti.kernel(fastcache=gs.use_fastcache) def kernel_set_links_quat( relative: ti.i32, @@ -6315,6 +7602,35 @@ def kernel_set_links_quat( rigid_global_info.qpos[q_start + j + 3, i_b] = quat[i_b_, i_l_, j] +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_set_links_quat_grad( + relative: ti.i32, + quat_grad: ti.types.ndarray(), + links_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + links_info: array_class.LinksInfo, + links_state: array_class.LinksState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): + i_b = envs_idx[i_b_] + i_l = links_idx[i_l_] + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + if links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]: + for j in ti.static(range(4)): + quat_grad[i_b_, i_l_, j] = links_state.quat.grad[i_l, i_b][j] + links_state.quat.grad[i_l, i_b][j] = 0.0 + else: + q_start = links_info.q_start[I_l] + for j in ti.static(range(4)): + quat_grad[i_b_, i_l_, j] = rigid_global_info.qpos.grad[q_start + j + 3, i_b] + rigid_global_info.qpos.grad[q_start + j + 3, i_b] = 0.0 + + @ti.kernel(fastcache=gs.use_fastcache) def kernel_set_links_mass_shift( mass: ti.types.ndarray(), @@ -6603,6 +7919,20 @@ def kernel_set_dofs_velocity( dofs_state.vel[dofs_idx[i_d_], envs_idx[i_b_]] = velocity[i_b_, i_d_] +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_set_dofs_velocity_grad( + velocity_grad: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + dofs_state: array_class.DofsState, + static_rigid_sim_config: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + velocity_grad[i_b_, i_d_] = dofs_state.vel.grad[dofs_idx[i_d_], envs_idx[i_b_]] + dofs_state.vel.grad[dofs_idx[i_d_], envs_idx[i_b_]] = 0.0 + + @ti.kernel(fastcache=gs.use_fastcache) def kernel_set_dofs_zero_velocity( dofs_idx: ti.types.ndarray(), diff --git a/genesis/engine/states/__init__.py b/genesis/engine/states/__init__.py index e69de29bb..2e316a1de 100644 --- a/genesis/engine/states/__init__.py +++ b/genesis/engine/states/__init__.py @@ -0,0 +1,2 @@ +from .solvers import * +from .cache import * diff --git a/genesis/engine/states/entities.py b/genesis/engine/states/entities.py index b0135547a..b6ee1f6dd 100644 --- a/genesis/engine/states/entities.py +++ b/genesis/engine/states/entities.py @@ -188,3 +188,41 @@ def vel(self): @property def active(self): return self._active + + +class RigidEntityState(RBC): + """ + Dynamic state queried from a genesis RigidEntity. + """ + + def __init__(self, entity, s_global): + self._entity = entity + self._s_global = s_global + + num_batch = self._entity._solver._B + requires_grad = self._entity.scene.requires_grad + scene = self._entity.scene + self._pos = gs.zeros((num_batch, 3), dtype=float, requires_grad=requires_grad, scene=scene) + self._quat = gs.zeros((num_batch, 4), dtype=float, requires_grad=requires_grad, scene=scene) + + def serializable(self): + self._entity = None + + self._pos = self._pos.detach() + self._quat = self._quat.detach() + + @property + def entity(self): + return self._entity + + @property + def s_global(self): + return self._s_global + + @property + def pos(self): + return self._pos + + @property + def quat(self): + return self._quat diff --git a/genesis/engine/states/solvers.py b/genesis/engine/states/solvers.py index c1182c501..5135346e9 100644 --- a/genesis/engine/states/solvers.py +++ b/genesis/engine/states/solvers.py @@ -48,9 +48,11 @@ class RigidSolverState: Dynamic state queried from a RigidSolver. """ - def __init__(self, scene): + def __init__(self, scene, s_global): self.scene = scene + self._s_global = s_global + _B = scene.sim.rigid_solver._B args = { "dtype": gs.tc_float, @@ -59,6 +61,7 @@ def __init__(self, scene): } self.qpos = gs.zeros((_B, scene.sim.rigid_solver.n_qs), **args) self.dofs_vel = gs.zeros((_B, scene.sim.rigid_solver.n_dofs), **args) + self.dofs_acc = gs.zeros((_B, scene.sim.rigid_solver.n_dofs), **args) self.links_pos = gs.zeros((_B, scene.sim.rigid_solver.n_links, 3), **args) self.links_quat = gs.zeros((_B, scene.sim.rigid_solver.n_links, 4), **args) self.i_pos_shift = gs.zeros((_B, scene.sim.rigid_solver.n_links, 3), **args) @@ -75,6 +78,10 @@ def serializable(self): self.mass_shift = self.mass_shift.detach() self.friction_ratio = self.friction_ratio.detach() + @property + def s_global(self): + return self._s_global + class AvatarSolverState: """ diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index cd46f92e2..663c529e6 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -18,6 +18,7 @@ V_MAT = ti.Matrix.ndarray if gs.use_ndarray else ti.Matrix.field DATA_ORIENTED = partial(dataclasses.dataclass, frozen=True) if gs.use_ndarray else ti.data_oriented +PLACEHOLDER = V(dtype=gs.ti_float, shape=()) def maybe_shape(shape, is_on): @@ -73,6 +74,7 @@ def V_SCALAR_FROM(dtype, value): @DATA_ORIENTED class StructRigidGlobalInfo(metaclass=BASE_METACLASS): + # *_bw: Cache for backward pass n_awake_dofs: V_ANNOTATION awake_dofs: V_ANNOTATION n_awake_entities: V_ANNOTATION @@ -81,11 +83,13 @@ class StructRigidGlobalInfo(metaclass=BASE_METACLASS): awake_links: V_ANNOTATION qpos0: V_ANNOTATION qpos: V_ANNOTATION + qpos_next: V_ANNOTATION links_T: V_ANNOTATION envs_offset: V_ANNOTATION geoms_init_AABB: V_ANNOTATION mass_mat: V_ANNOTATION mass_mat_L: V_ANNOTATION + mass_mat_L_bw: V_ANNOTATION mass_mat_D_inv: V_ANNOTATION mass_mat_mask: V_ANNOTATION meaninertia: V_ANNOTATION @@ -108,6 +112,7 @@ class StructRigidGlobalInfo(metaclass=BASE_METACLASS): def get_rigid_global_info(solver): _B = solver._B + requires_grad = solver._requires_grad return StructRigidGlobalInfo( n_awake_dofs=V(dtype=gs.ti_int, shape=(_B,)), @@ -117,13 +122,15 @@ def get_rigid_global_info(solver): n_awake_links=V(dtype=gs.ti_int, shape=(_B,)), awake_links=V(dtype=gs.ti_int, shape=(solver.n_links_, _B)), qpos0=V(dtype=gs.ti_float, shape=(solver.n_qs_, _B)), - qpos=V(dtype=gs.ti_float, shape=(solver.n_qs_, _B)), + qpos=V(dtype=gs.ti_float, shape=(solver.n_qs_, _B), needs_grad=requires_grad), + qpos_next=V(dtype=gs.ti_float, shape=(solver.n_qs_, _B), needs_grad=requires_grad), links_T=V_MAT(n=4, m=4, dtype=gs.ti_float, shape=(solver.n_links_,)), envs_offset=V_VEC(3, dtype=gs.ti_float, shape=(_B,)), geoms_init_AABB=V_VEC(3, dtype=gs.ti_float, shape=(solver.n_geoms_, 8)), - mass_mat=V(dtype=gs.ti_float, shape=(solver.n_dofs_, solver.n_dofs_, _B)), - mass_mat_L=V(dtype=gs.ti_float, shape=(solver.n_dofs_, solver.n_dofs_, _B)), - mass_mat_D_inv=V(dtype=gs.ti_float, shape=(solver.n_dofs_, _B)), + mass_mat=V(dtype=gs.ti_float, shape=(solver.n_dofs_, solver.n_dofs_, _B), needs_grad=requires_grad), + mass_mat_L=V(dtype=gs.ti_float, shape=(solver.n_dofs_, solver.n_dofs_, _B), needs_grad=requires_grad), + mass_mat_L_bw=V(dtype=gs.ti_float, shape=(2, solver.n_dofs_, solver.n_dofs_, _B), needs_grad=requires_grad), + mass_mat_D_inv=V(dtype=gs.ti_float, shape=(solver.n_dofs_, _B), needs_grad=requires_grad), mass_mat_mask=V(dtype=gs.ti_bool, shape=(solver.n_entities_, _B)), meaninertia=V(dtype=gs.ti_float, shape=(_B,)), mass_parent_mask=V(dtype=gs.ti_float, shape=(solver.n_dofs_, solver.n_dofs_)), @@ -1120,6 +1127,7 @@ def get_dofs_info(solver): @DATA_ORIENTED class StructDofsState(metaclass=BASE_METACLASS): + # *_bw: Cache to avoid overwriting for backward pass force: V_ANNOTATION qf_bias: V_ANNOTATION qf_passive: V_ANNOTATION @@ -1129,8 +1137,11 @@ class StructDofsState(metaclass=BASE_METACLASS): pos: V_ANNOTATION vel: V_ANNOTATION vel_prev: V_ANNOTATION + vel_next: V_ANNOTATION acc: V_ANNOTATION + acc_bw: V_ANNOTATION acc_smooth: V_ANNOTATION + acc_smooth_bw: V_ANNOTATION qf_smooth: V_ANNOTATION qf_constraint: V_ANNOTATION cdof_ang: V_ANNOTATION @@ -1150,32 +1161,36 @@ class StructDofsState(metaclass=BASE_METACLASS): def get_dofs_state(solver): shape = (solver.n_dofs_, solver._B) + requires_grad = solver._requires_grad return StructDofsState( - force=V(dtype=gs.ti_float, shape=shape), - qf_bias=V(dtype=gs.ti_float, shape=shape), - qf_passive=V(dtype=gs.ti_float, shape=shape), - qf_actuator=V(dtype=gs.ti_float, shape=shape), - qf_applied=V(dtype=gs.ti_float, shape=shape), - act_length=V(dtype=gs.ti_float, shape=shape), - pos=V(dtype=gs.ti_float, shape=shape), - vel=V(dtype=gs.ti_float, shape=shape), - vel_prev=V(dtype=gs.ti_float, shape=shape), - acc=V(dtype=gs.ti_float, shape=shape), - acc_smooth=V(dtype=gs.ti_float, shape=shape), - qf_smooth=V(dtype=gs.ti_float, shape=shape), - qf_constraint=V(dtype=gs.ti_float, shape=shape), - cdof_ang=V(dtype=gs.ti_vec3, shape=shape), - cdof_vel=V(dtype=gs.ti_vec3, shape=shape), - cdofvel_ang=V(dtype=gs.ti_vec3, shape=shape), - cdofvel_vel=V(dtype=gs.ti_vec3, shape=shape), - cdofd_ang=V(dtype=gs.ti_vec3, shape=shape), - cdofd_vel=V(dtype=gs.ti_vec3, shape=shape), - f_vel=V(dtype=gs.ti_vec3, shape=shape), - f_ang=V(dtype=gs.ti_vec3, shape=shape), - ctrl_force=V(dtype=gs.ti_float, shape=shape), - ctrl_pos=V(dtype=gs.ti_float, shape=shape), - ctrl_vel=V(dtype=gs.ti_float, shape=shape), + force=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + qf_bias=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + qf_passive=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + qf_actuator=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + qf_applied=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + act_length=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + pos=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + vel=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + vel_prev=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + vel_next=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + acc=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + acc_bw=V(dtype=gs.ti_float, shape=(2, solver.n_dofs_, solver._B), needs_grad=requires_grad), + acc_smooth=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + acc_smooth_bw=V(dtype=gs.ti_float, shape=(2, solver.n_dofs_, solver._B), needs_grad=requires_grad), + qf_smooth=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + qf_constraint=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + cdof_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdof_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdofvel_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdofvel_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdofd_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdofd_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + f_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + f_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + ctrl_force=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + ctrl_pos=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + ctrl_vel=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), ctrl_mode=V(dtype=gs.ti_int, shape=shape), hibernated=V(dtype=gs.ti_int, shape=shape), ) @@ -1186,6 +1201,7 @@ def get_dofs_state(solver): @DATA_ORIENTED class StructLinksState(metaclass=BASE_METACLASS): + # *_bw: Cache to avoid overwriting for backward pass cinr_inertial: V_ANNOTATION cinr_pos: V_ANNOTATION cinr_quat: V_ANNOTATION @@ -1198,16 +1214,24 @@ class StructLinksState(metaclass=BASE_METACLASS): cdd_ang: V_ANNOTATION pos: V_ANNOTATION quat: V_ANNOTATION + pos_bw: V_ANNOTATION + quat_bw: V_ANNOTATION i_pos: V_ANNOTATION + i_pos_bw: V_ANNOTATION i_quat: V_ANNOTATION j_pos: V_ANNOTATION j_quat: V_ANNOTATION + j_pos_bw: V_ANNOTATION + j_quat_bw: V_ANNOTATION j_vel: V_ANNOTATION j_ang: V_ANNOTATION cd_ang: V_ANNOTATION cd_vel: V_ANNOTATION + cd_ang_bw: V_ANNOTATION + cd_vel_bw: V_ANNOTATION mass_sum: V_ANNOTATION root_COM: V_ANNOTATION # COM of the kinematic tree + root_COM_bw: V_ANNOTATION mass_shift: V_ANNOTATION i_pos_shift: V_ANNOTATION cacc_ang: V_ANNOTATION @@ -1223,42 +1247,54 @@ class StructLinksState(metaclass=BASE_METACLASS): def get_links_state(solver): + max_n_joints_per_link = solver._static_rigid_sim_config.max_n_joints_per_link shape = (solver.n_links_, solver._B) + shape_bw = (solver.n_links_, max_n_joints_per_link + 1, solver._B) + + requires_grad = solver._requires_grad return StructLinksState( - cinr_inertial=V(dtype=gs.ti_mat3, shape=shape), - cinr_pos=V(dtype=gs.ti_vec3, shape=shape), - cinr_quat=V(dtype=gs.ti_vec4, shape=shape), - cinr_mass=V(dtype=gs.ti_float, shape=shape), - crb_inertial=V(dtype=gs.ti_mat3, shape=shape), - crb_pos=V(dtype=gs.ti_vec3, shape=shape), - crb_quat=V(dtype=gs.ti_vec4, shape=shape), - crb_mass=V(dtype=gs.ti_float, shape=shape), - cdd_vel=V(dtype=gs.ti_vec3, shape=shape), - cdd_ang=V(dtype=gs.ti_vec3, shape=shape), - pos=V(dtype=gs.ti_vec3, shape=shape), - quat=V(dtype=gs.ti_vec4, shape=shape), - i_pos=V(dtype=gs.ti_vec3, shape=shape), - i_quat=V(dtype=gs.ti_vec4, shape=shape), - j_pos=V(dtype=gs.ti_vec3, shape=shape), - j_quat=V(dtype=gs.ti_vec4, shape=shape), - j_vel=V(dtype=gs.ti_vec3, shape=shape), - j_ang=V(dtype=gs.ti_vec3, shape=shape), - cd_ang=V(dtype=gs.ti_vec3, shape=shape), - cd_vel=V(dtype=gs.ti_vec3, shape=shape), - mass_sum=V(dtype=gs.ti_float, shape=shape), - root_COM=V(dtype=gs.ti_vec3, shape=shape), - mass_shift=V(dtype=gs.ti_float, shape=shape), - i_pos_shift=V(dtype=gs.ti_vec3, shape=shape), - cacc_ang=V(dtype=gs.ti_vec3, shape=shape), - cacc_lin=V(dtype=gs.ti_vec3, shape=shape), - cfrc_ang=V(dtype=gs.ti_vec3, shape=shape), - cfrc_vel=V(dtype=gs.ti_vec3, shape=shape), - cfrc_applied_ang=V(dtype=gs.ti_vec3, shape=shape), - cfrc_applied_vel=V(dtype=gs.ti_vec3, shape=shape), - cfrc_coupling_ang=V(dtype=gs.ti_vec3, shape=shape), - cfrc_coupling_vel=V(dtype=gs.ti_vec3, shape=shape), - contact_force=V(dtype=gs.ti_vec3, shape=shape), + cinr_inertial=V(dtype=gs.ti_mat3, shape=shape, needs_grad=requires_grad), + cinr_pos=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cinr_quat=V(dtype=gs.ti_vec4, shape=shape, needs_grad=requires_grad), + cinr_mass=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + crb_inertial=V(dtype=gs.ti_mat3, shape=shape, needs_grad=requires_grad), + crb_pos=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + crb_quat=V(dtype=gs.ti_vec4, shape=shape, needs_grad=requires_grad), + crb_mass=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + cdd_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdd_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + pos=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + quat=V(dtype=gs.ti_vec4, shape=shape, needs_grad=requires_grad), + pos_bw=V(dtype=gs.ti_vec3, shape=shape_bw, needs_grad=requires_grad), + quat_bw=V(dtype=gs.ti_vec4, shape=shape_bw, needs_grad=requires_grad), + i_pos=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + i_pos_bw=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + i_quat=V(dtype=gs.ti_vec4, shape=shape, needs_grad=requires_grad), + j_pos=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + j_quat=V(dtype=gs.ti_vec4, shape=shape, needs_grad=requires_grad), + j_pos_bw=V(dtype=gs.ti_vec3, shape=shape_bw, needs_grad=requires_grad), + j_quat_bw=V(dtype=gs.ti_vec4, shape=shape_bw, needs_grad=requires_grad), + j_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + j_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cd_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cd_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cd_ang_bw=V(dtype=gs.ti_vec3, shape=shape_bw, needs_grad=requires_grad), + cd_vel_bw=V(dtype=gs.ti_vec3, shape=shape_bw, needs_grad=requires_grad), + mass_sum=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + root_COM=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + root_COM_bw=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + mass_shift=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + i_pos_shift=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cacc_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cacc_lin=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_applied_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_applied_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_coupling_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_coupling_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + contact_force=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), hibernated=V(dtype=gs.ti_int, shape=shape), ) @@ -1348,10 +1384,11 @@ class StructJointsState(metaclass=BASE_METACLASS): def get_joints_state(solver): shape = (solver.n_joints_, solver._B) + requires_grad = solver._requires_grad return StructJointsState( - xanchor=V(dtype=gs.ti_vec3, shape=shape), - xaxis=V(dtype=gs.ti_vec3, shape=shape), + xanchor=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + xaxis=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), ) @@ -1703,6 +1740,30 @@ def get_entities_state(solver): ) +# =========================================== RigidAdjointCache =========================================== +@DATA_ORIENTED +class StructRigidAdjointCache(metaclass=BASE_METACLASS): + # This cache stores intermediate values during rigid body simulation to use Taichi's AD. Taichi's AD requires + # us not to overwrite the values that have been read during the forward pass, so we need to store the intemediate + # values in this cache to avoid overwriting them. Specifically, after we compute next frame's qpos, dofs_vel, and + # dofs_acc, we need to store them in this cache because we overwrite the values in the next frame. See how + # [kernel_save_adjoint_cache] is used in [rigid_solver_decomp.py] to store the values in this cache. + qpos: V_ANNOTATION + dofs_vel: V_ANNOTATION + dofs_acc: V_ANNOTATION + + +def get_rigid_adjoint_cache(solver): + substeps_local = solver._sim.substeps_local + requires_grad = solver._requires_grad + + return StructRigidAdjointCache( + qpos=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_qs_, solver._B), needs_grad=requires_grad), + dofs_vel=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B), needs_grad=requires_grad), + dofs_acc=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B), needs_grad=requires_grad), + ) + + # =================================== StructRigidSimStaticConfig =================================== @@ -1759,6 +1820,14 @@ def __init__(self, solver): self.entities_info = get_entities_info(solver) self.entities_state = get_entities_state(solver) + if solver._static_rigid_sim_config.requires_grad: + # Data structures required for backward pass + self.dofs_state_adjoint_cache = get_dofs_state(solver) + self.links_state_adjoint_cache = get_links_state(solver) + self.joints_state_adjoint_cache = get_joints_state(solver) + self.geoms_state_adjoint_cache = get_geoms_state(solver) + + self.rigid_adjoint_cache = get_rigid_adjoint_cache(solver) self.errno = V_SCALAR_FROM(dtype=gs.ti_int, value=0) @@ -1794,3 +1863,4 @@ def __init__(self, solver): SDFInfo = StructSDFInfo if gs.use_ndarray else ti.template() ContactIslandState = StructContactIslandState if gs.use_ndarray else ti.template() DiffContactInput = StructDiffContactInput if gs.use_ndarray else ti.template() +RigidAdjointCache = StructRigidAdjointCache if gs.use_ndarray else ti.template() diff --git a/genesis/utils/geom.py b/genesis/utils/geom.py index 16a1a1d6a..f387fae55 100644 --- a/genesis/utils/geom.py +++ b/genesis/utils/geom.py @@ -83,8 +83,12 @@ def ti_rotvec_to_R(rotvec, eps): def ti_rotvec_to_quat(rotvec, eps): quat = ti.Vector.zero(gs.ti_float, 4) - theta = rotvec.norm() - if theta > eps: + # We need to use [norm_sqr] instead of [norm] to avoid nan gradients in the backward pass. Even when theta = 0, + # the gradient of [norm] operation is computed and used (note that the gradient becomes NaN when theta = 0). This + # is seemd to be a bug in Taichi autodiff @TODO: change back after the bug is fixed. + thetasq = rotvec.norm_sqr() + if thetasq > (eps**2): + theta = ti.sqrt(thetasq) theta_half = 0.5 * theta c, s = ti.cos(theta_half), ti.sin(theta_half) diff --git a/genesis/utils/path_planning.py b/genesis/utils/path_planning.py index 615195b3f..fe3cc1df6 100644 --- a/genesis/utils/path_planning.py +++ b/genesis/utils/path_planning.py @@ -426,6 +426,7 @@ def _kernel_rrt_step1( entities_info, rigid_global_info, self._solver._static_rigid_sim_config, + is_backward=False, ) gs.engine.solvers.rigid.rigid_solver_decomp.func_update_geoms( i_b, @@ -436,6 +437,7 @@ def _kernel_rrt_step1( rigid_global_info, self._solver._static_rigid_sim_config, force_update_fixed_geoms=False, + is_backward=False, ) @ti.kernel @@ -791,6 +793,7 @@ def _kernel_rrt_connect_step1( entities_info, rigid_global_info, self._solver._static_rigid_sim_config, + is_backward=False, ) gs.engine.solvers.rigid.rigid_solver_decomp.func_update_geoms( i_b, @@ -801,6 +804,7 @@ def _kernel_rrt_connect_step1( rigid_global_info, self._solver._static_rigid_sim_config, force_update_fixed_geoms=False, + is_backward=False, ) @ti.kernel diff --git a/tests/test_grad.py b/tests/test_grad.py index 05d504f28..179bb0c24 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -207,6 +207,7 @@ def constraint_solver_resolve(): rigid_global_info=rigid_solver._rigid_global_info, static_rigid_sim_config=rigid_solver._static_rigid_sim_config, contact_island_state=constraint_solver.contact_island.contact_island_state, + is_backward=False, ) constraint_solver.add_equality_constraints() rigid_solver.collider.detection() @@ -265,7 +266,7 @@ def compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force): ### Compute directional derivatives along random directions FD_EPS = 1e-3 - TRIALS = 100 + TRIALS = 200 for dL_dx, x_type in ( (dL_dforce, "force"), @@ -409,3 +410,91 @@ def test_differentiable_push(precision, show_viewer): for v_i in v_list[:-1]: assert (v_i.grad.abs() > gs.EPS).any() assert (v_list[-1].grad.abs() < gs.EPS).all() + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +def test_differentiable_rigid(show_viewer): + dt = 1e-2 + horizon = 100 + substeps = 1 + goal_pos = gs.tensor([0.7, 1.0, 0.05]) + goal_quat = gs.tensor([0.3, 0.2, 0.1, 0.9]) + goal_quat = goal_quat / torch.norm(goal_quat, dim=-1, keepdim=True) + + scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, gravity=(0, 0, -1)), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + enable_self_collision=False, + enable_joint_limit=False, + disable_constraint=True, + use_contact_island=False, + use_hibernation=False, + ), + viewer_options=gs.options.ViewerOptions( + camera_pos=(2.5, -0.15, 2.42), + camera_lookat=(0.5, 0.5, 0.1), + ), + show_viewer=show_viewer, + ) + + box = scene.add_entity( + gs.morphs.Box( + pos=(0, 0, 0), + size=(0.1, 0.1, 0.2), + ), + surface=gs.surfaces.Default( + color=(0.9, 0.0, 0.0, 1.0), + ), + ) + if show_viewer: + target = scene.add_entity( + gs.morphs.Box( + pos=goal_pos, + quat=goal_quat, + size=(0.1, 0.1, 0.2), + ), + surface=gs.surfaces.Default( + color=(0.0, 0.9, 0.0, 0.5), + ), + ) + + scene.build() + + num_iter = 200 + lr = 1e-2 + + init_pos = gs.tensor([0.3, 0.1, 0.28], requires_grad=True) + init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True) + optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr) + + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3) + + for iter in range(num_iter): + scene.reset() + + box.set_pos(init_pos) + box.set_quat(init_quat) + + loss = 0 + for i in range(horizon): + scene.step() + if show_viewer: + target.set_pos(goal_pos) + target.set_quat(goal_quat) + + box_state = box.get_state() + box_pos = box_state.pos + box_quat = box_state.quat + loss = torch.abs(box_pos - goal_pos).sum() + torch.abs(box_quat - goal_quat).sum() + + optimizer.zero_grad() + loss.backward() # this lets gradient flow all the way back to tensor input + optimizer.step() + scheduler.step() + + with torch.no_grad(): + init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True) + + assert_allclose(loss, 0.0, atol=1e-2)