1717from genesis .utils import mjcf as mju
1818from genesis .utils import terrain as tu
1919from genesis .utils import urdf as uu
20- from genesis .utils .misc import ALLOCATE_TENSOR_WARNING , DeprecationError , ti_to_torch
20+ from genesis .utils .misc import ALLOCATE_TENSOR_WARNING , DeprecationError , ti_to_torch , to_gs_tensor
21+ from genesis .engine .states .entities import RigidEntityState
2122
2223from ..base_entity import Entity
2324from .rigid_equality import RigidEquality
@@ -95,6 +96,23 @@ def __init__(
9596
9697 self ._load_model ()
9798
99+ self .init_tgt_vars ()
100+ self .init_ckpt ()
101+
102+ def init_tgt_keys (self ):
103+
104+ self ._tgt_keys = ["pos" , "quat" , "qpos" , "dofs_velocity" ]
105+
106+ def init_tgt_vars (self ):
107+
108+ # temp variable to store targets for next step
109+ self ._tgt = []
110+ self ._tgt_buffer = []
111+ self .init_tgt_keys ()
112+
113+ def init_ckpt (self ):
114+ self ._ckpt = dict ()
115+
98116 def _load_model (self ):
99117 self ._links = gs .List ()
100118 self ._joints = gs .List ()
@@ -1445,6 +1463,7 @@ def _kernel_forward_kinematics(
14451463 self ._solver .entities_info ,
14461464 self ._solver ._rigid_global_info ,
14471465 self ._solver ._static_rigid_sim_config ,
1466+ False ,
14481467 )
14491468
14501469 ti .loop_config (serialize = self ._solver ._para_level < gs .PARA_LEVEL .PARTIAL )
@@ -1471,6 +1490,7 @@ def _kernel_forward_kinematics(
14711490 self ._solver .entities_info ,
14721491 self ._solver ._rigid_global_info ,
14731492 self ._solver ._static_rigid_sim_config ,
1493+ False ,
14741494 )
14751495
14761496 # ------------------------------------------------------------------------------------
@@ -1605,6 +1625,128 @@ def plan_path(
16051625 # ------------------------------------------------------------------------------------
16061626 # ---------------------------------- control & io ------------------------------------
16071627 # ------------------------------------------------------------------------------------
1628+ def process_input (self , in_backward = False ):
1629+ if in_backward :
1630+ # use negative index because buffer length might not be full
1631+ index = self ._sim .cur_step_local - self ._sim ._steps_local
1632+ self ._tgt = self ._tgt_buffer [index ].copy ()
1633+ else :
1634+ self ._tgt_buffer .append (self ._tgt .copy ())
1635+
1636+ # Apply targets sequentially
1637+ _tgt = self ._tgt .copy ()
1638+ for tgt in _tgt :
1639+ k = tgt ["key" ]
1640+ assert k in self ._tgt_keys , f"Invalid target key: { k } not in { self ._tgt_keys } "
1641+
1642+ # We do not need zero velocity here because if it was true, [set_dofs_velocity] from zero_velocity would
1643+ # be in [tgt]
1644+ zero_velocity = False
1645+ if k == "pos" :
1646+ _pos = tgt ["pos" ]
1647+ _envs_idx = tgt ["envs_idx" ]
1648+ _relative = tgt ["relative" ]
1649+ _unsafe = tgt ["unsafe" ]
1650+
1651+ self .set_pos (_pos , envs_idx = _envs_idx , relative = _relative , zero_velocity = zero_velocity , unsafe = _unsafe )
1652+ elif k == "quat" :
1653+ _quat = tgt ["quat" ]
1654+ _envs_idx = tgt ["envs_idx" ]
1655+ _relative = tgt ["relative" ]
1656+ _unsafe = tgt ["unsafe" ]
1657+
1658+ self .set_quat (
1659+ _quat , envs_idx = _envs_idx , relative = _relative , zero_velocity = zero_velocity , unsafe = _unsafe
1660+ )
1661+ elif k == "qpos" :
1662+ _qpos = tgt ["qpos" ]
1663+ _qs_idx_local = tgt ["qs_idx_local" ]
1664+ _envs_idx = tgt ["envs_idx" ]
1665+ _unsafe = tgt ["unsafe" ]
1666+
1667+ self .set_qpos (
1668+ _qpos , qs_idx_local = _qs_idx_local , envs_idx = _envs_idx , zero_velocity = zero_velocity , unsafe = _unsafe
1669+ )
1670+ elif k == "dofs_velocity" :
1671+ _velocity = tgt ["velocity" ]
1672+ _dofs_idx_local = tgt ["dofs_idx_local" ]
1673+ _envs_idx = tgt ["envs_idx" ]
1674+ _unsafe = tgt ["unsafe" ]
1675+
1676+ self .set_dofs_velocity (_velocity , dofs_idx_local = _dofs_idx_local , envs_idx = _envs_idx , unsafe = _unsafe )
1677+
1678+ self ._tgt = []
1679+
1680+ def process_input_grad (self ):
1681+ index = self ._sim .cur_step_local - self ._sim ._steps_local
1682+ _tgt = self ._tgt_buffer [index ].copy ()
1683+
1684+ for tgt in reversed (_tgt ):
1685+ k = tgt ["key" ]
1686+ assert k in self ._tgt_keys , f"Invalid target key: { k } not in { self ._tgt_keys } "
1687+ if k == "pos" :
1688+ _pos = tgt ["pos" ]
1689+ _envs_idx = tgt ["envs_idx" ]
1690+ _relative = tgt ["relative" ]
1691+ _unsafe = tgt ["unsafe" ]
1692+
1693+ if _pos is not None and _pos .requires_grad :
1694+ _pos ._backward_from_ti (self .set_pos_grad , _envs_idx , _relative , _unsafe )
1695+
1696+ elif k == "quat" :
1697+ _quat = tgt ["quat" ]
1698+ _envs_idx = tgt ["envs_idx" ]
1699+ _relative = tgt ["relative" ]
1700+ _unsafe = tgt ["unsafe" ]
1701+
1702+ if _quat is not None and _quat .requires_grad :
1703+ _quat ._backward_from_ti (self .set_quat_grad , _envs_idx , _relative , _unsafe )
1704+
1705+ elif k == "qpos" :
1706+ _qpos = tgt ["qpos" ]
1707+ _qs_idx_local = tgt ["qs_idx_local" ]
1708+ _envs_idx = tgt ["envs_idx" ]
1709+ _unsafe = tgt ["unsafe" ]
1710+
1711+ if _qpos is not None and _qpos .requires_grad :
1712+ # TODO: Not implemented yet
1713+ raise NotImplementedError ("Backward pass for set_qpos_grad is not implemented yet." )
1714+
1715+ elif k == "dofs_velocity" :
1716+ _velocity = tgt ["velocity" ]
1717+ _dofs_idx_local = tgt ["dofs_idx_local" ]
1718+ _envs_idx = tgt ["envs_idx" ]
1719+ _unsafe = tgt ["unsafe" ]
1720+
1721+ if _velocity is not None and _velocity .requires_grad :
1722+ _velocity ._backward_from_ti (self .set_dofs_velocity_grad , _dofs_idx_local , _envs_idx , _unsafe )
1723+
1724+ def save_ckpt (self , ckpt_name ):
1725+ if ckpt_name not in self ._ckpt :
1726+ self ._ckpt [ckpt_name ] = {}
1727+ self ._ckpt [ckpt_name ]["_tgt_buffer" ] = self ._tgt_buffer .copy ()
1728+ self ._tgt_buffer .clear ()
1729+
1730+ def load_ckpt (self , ckpt_name ):
1731+ self ._tgt_buffer = self ._ckpt [ckpt_name ]["_tgt_buffer" ].copy ()
1732+
1733+ def reset_grad (self ):
1734+ self ._tgt_buffer .clear ()
1735+
1736+ @gs .assert_built
1737+ def get_state (self ):
1738+ state = RigidEntityState (self , self ._sim .cur_step_global )
1739+
1740+ solver_state = self ._solver .get_state ()
1741+ pos = solver_state .links_pos [:, self ._base_links_idx ].squeeze (- 2 )
1742+ quat = solver_state .links_quat [:, self ._base_links_idx ].squeeze (- 2 )
1743+
1744+ assert state ._pos .shape == pos .shape
1745+ assert state ._quat .shape == quat .shape
1746+ state ._pos = pos
1747+ state ._quat = quat
1748+
1749+ return state
16081750
16091751 def get_joint (self , name = None , uid = None ):
16101752 """
@@ -1949,6 +2091,18 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
19492091 envs_idx : None | array_like, optional
19502092 The indices of the environments. If None, all environments will be considered. Defaults to None.
19512093 """
2094+ # Save in [tgt] for backward pass
2095+ self ._tgt .append (
2096+ {
2097+ "key" : "pos" ,
2098+ "pos" : pos ,
2099+ "envs_idx" : envs_idx ,
2100+ "relative" : relative ,
2101+ "zero_velocity" : zero_velocity ,
2102+ "unsafe" : unsafe ,
2103+ }
2104+ )
2105+
19522106 if not unsafe :
19532107 _pos = torch .as_tensor (pos , dtype = gs .tc_float , device = gs .device ).contiguous ()
19542108 if _pos is not pos :
@@ -1965,6 +2119,18 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
19652119 if zero_velocity :
19662120 self .zero_all_dofs_velocity (envs_idx , unsafe = unsafe )
19672121
2122+ @gs .assert_built
2123+ def set_pos_grad (self , envs_idx , relative , unsafe , pos_grad ):
2124+ tmp_pos_grad = pos_grad .unsqueeze (- 2 ).clone ()
2125+ self ._solver .set_base_links_pos_grad (
2126+ self ._base_links_idx ,
2127+ envs_idx ,
2128+ relative ,
2129+ unsafe ,
2130+ tmp_pos_grad ,
2131+ )
2132+ pos_grad .data = tmp_pos_grad .squeeze (- 2 )
2133+
19682134 @gs .assert_built
19692135 def set_quat (self , quat , envs_idx = None , * , relative = False , zero_velocity = True , unsafe = False ):
19702136 """
@@ -1983,6 +2149,17 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
19832149 envs_idx : None | array_like, optional
19842150 The indices of the environments. If None, all environments will be considered. Defaults to None.
19852151 """
2152+ # Save in [tgt] for backward pass
2153+ self ._tgt .append (
2154+ {
2155+ "key" : "quat" ,
2156+ "quat" : quat ,
2157+ "envs_idx" : envs_idx ,
2158+ "relative" : relative ,
2159+ "zero_velocity" : zero_velocity ,
2160+ "unsafe" : unsafe ,
2161+ }
2162+ )
19862163 if not unsafe :
19872164 _quat = torch .as_tensor (quat , dtype = gs .tc_float , device = gs .device ).contiguous ()
19882165 if _quat is not quat :
@@ -1999,6 +2176,18 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
19992176 if zero_velocity :
20002177 self .zero_all_dofs_velocity (envs_idx , unsafe = unsafe )
20012178
2179+ @gs .assert_built
2180+ def set_quat_grad (self , envs_idx , relative , unsafe , quat_grad ):
2181+ tmp_quat_grad = quat_grad .unsqueeze (- 2 ).clone ()
2182+ self ._solver .set_base_links_quat_grad (
2183+ self ._base_links_idx ,
2184+ envs_idx ,
2185+ relative ,
2186+ unsafe ,
2187+ tmp_quat_grad ,
2188+ )
2189+ quat_grad .data = tmp_quat_grad .squeeze (- 2 )
2190+
20022191 @gs .assert_built
20032192 def get_verts (self ):
20042193 """
@@ -2091,6 +2280,18 @@ def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True
20912280 zero_velocity : bool, optional
20922281 Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose.
20932282 """
2283+ # Save in [tgt] for backward pass
2284+ self ._tgt .append (
2285+ {
2286+ "key" : "qpos" ,
2287+ "qpos" : qpos ,
2288+ "qs_idx_local" : qs_idx_local ,
2289+ "envs_idx" : envs_idx ,
2290+ "zero_velocity" : zero_velocity ,
2291+ "unsafe" : unsafe ,
2292+ }
2293+ )
2294+
20942295 qs_idx = self ._get_idx (qs_idx_local , self .n_qs , self ._q_start , unsafe = True )
20952296 self ._solver .set_qpos (qpos , qs_idx , envs_idx , unsafe = unsafe , skip_forward = zero_velocity )
20962297 if zero_velocity :
@@ -2186,9 +2387,24 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *
21862387 envs_idx : None | array_like, optional
21872388 The indices of the environments. If None, all environments will be considered. Defaults to None.
21882389 """
2390+ # Save in [tgt] for backward pass
2391+ self ._tgt .append (
2392+ {
2393+ "key" : "dofs_velocity" ,
2394+ "velocity" : velocity ,
2395+ "dofs_idx_local" : dofs_idx_local ,
2396+ "envs_idx" : envs_idx ,
2397+ "unsafe" : unsafe ,
2398+ }
2399+ )
21892400 dofs_idx = self ._get_idx (dofs_idx_local , self .n_dofs , self ._dof_start , unsafe = True )
21902401 self ._solver .set_dofs_velocity (velocity , dofs_idx , envs_idx , skip_forward = False , unsafe = unsafe )
21912402
2403+ @gs .assert_built
2404+ def set_dofs_velocity_grad (self , dofs_idx_local , envs_idx , unsafe , velocity_grad ):
2405+ dofs_idx = self ._get_idx (dofs_idx_local , self .n_dofs , self ._dof_start , unsafe = True )
2406+ self ._solver .set_dofs_velocity_grad (dofs_idx , envs_idx , unsafe , velocity_grad )
2407+
21922408 @gs .assert_built
21932409 def set_dofs_frictionloss (self , frictionloss , dofs_idx_local = None , envs_idx = None , * , unsafe = False ):
21942410 """
0 commit comments