11from copy import copy
22from itertools import chain
33from typing import TYPE_CHECKING , Literal
4+ from functools import wraps
5+ import inspect
46
57import gstaichi as ti
68import numpy as np
1921from genesis .utils import terrain as tu
2022from genesis .utils import urdf as uu
2123from genesis .utils .misc import DeprecationError , broadcast_tensor , sanitize_index , ti_to_torch
24+ from genesis .engine .states .entities import RigidEntityState
2225
2326from ..base_entity import Entity
2427from .rigid_equality import RigidEquality
3134 from genesis .engine .solvers .rigid .rigid_solver_decomp import RigidSolver
3235
3336
37+ # Wrapper to track the arguments of a function and save them in the target buffer
38+ def tracked (fun ):
39+ sig = inspect .signature (fun )
40+
41+ @wraps (fun )
42+ def wrapper (self , * args , ** kwargs ):
43+ if self ._update_tgt_while_set :
44+ bound = sig .bind (self , * args , ** kwargs )
45+ bound .apply_defaults ()
46+ args_dict = dict (tuple (bound .arguments .items ())[1 :])
47+ self ._update_tgt (fun .__name__ , args_dict )
48+ return fun (self , * args , ** kwargs )
49+
50+ return wrapper
51+
52+
3453@ti .data_oriented
3554class RigidEntity (Entity ):
3655 """
@@ -97,6 +116,23 @@ def __init__(
97116
98117 self ._load_model ()
99118
119+ # Initialize target variables and checkpoint
120+ self ._tgt_keys = ("pos" , "quat" , "qpos" , "dofs_velocity" )
121+ self ._tgt = dict ()
122+ self ._tgt_buffer = list ()
123+ self ._ckpt = dict ()
124+ self ._update_tgt_while_set = self ._solver ._requires_grad
125+
126+ def _update_tgt (self , key , value ):
127+ # Set [self._tgt] value while keeping the insertion order between keys. When a new key is inserted or an existing
128+ # key is updated, the new element should be inserted at the end of the dict. This is because we need to keep
129+ # the insertion order to correctly pass the gradients in the backward pass.
130+ self ._tgt .pop (key , None )
131+ self ._tgt [key ] = value
132+
133+ def init_ckpt (self ):
134+ pass
135+
100136 def _load_model (self ):
101137 self ._links = gs .List ()
102138 self ._joints = gs .List ()
@@ -1598,6 +1634,92 @@ def plan_path(
15981634 # ------------------------------------------------------------------------------------
15991635 # ---------------------------------- control & io ------------------------------------
16001636 # ------------------------------------------------------------------------------------
1637+ def process_input (self , in_backward = False ):
1638+ if in_backward :
1639+ # use negative index because buffer length might not be full
1640+ index = self ._sim .cur_step_local - self ._sim ._steps_local
1641+ self ._tgt = self ._tgt_buffer [index ].copy ()
1642+ else :
1643+ self ._tgt_buffer .append (self ._tgt .copy ())
1644+
1645+ update_tgt_while_set = self ._update_tgt_while_set
1646+ # Apply targets in the order of insertion
1647+ for key in self ._tgt .keys ():
1648+ data_kwargs = self ._tgt [key ]
1649+
1650+ # We do not need zero velocity here because if it was true, [set_dofs_velocity] from zero_velocity would
1651+ # be in [tgt]
1652+ if "zero_velocity" in data_kwargs :
1653+ data_kwargs ["zero_velocity" ] = False
1654+ # Do not update [tgt], as input information is finalized at this point
1655+ self ._update_tgt_while_set = False
1656+
1657+ match key :
1658+ case "set_pos" :
1659+ self .set_pos (** data_kwargs )
1660+ case "set_quat" :
1661+ self .set_quat (** data_kwargs )
1662+ case "set_dofs_velocity" :
1663+ self .set_dofs_velocity (** data_kwargs )
1664+ case _:
1665+ gs .raise_exception (f"Invalid target key: { key } not in { self ._tgt_keys } " )
1666+
1667+ self ._tgt = dict ()
1668+ self ._update_tgt_while_set = update_tgt_while_set
1669+
1670+ def process_input_grad (self ):
1671+ index = self ._sim .cur_step_local - self ._sim ._steps_local
1672+ for key in reversed (self ._tgt_buffer [index ].keys ()):
1673+ data_kwargs = self ._tgt_buffer [index ][key ]
1674+
1675+ match key :
1676+ # We need to unpack the data_kwargs because [_backward_from_ti] only supports positional arguments
1677+ case "set_pos" :
1678+ pos = data_kwargs .pop ("pos" )
1679+ if pos .requires_grad :
1680+ pos ._backward_from_ti (self .set_pos_grad , data_kwargs ["envs_idx" ], data_kwargs ["relative" ])
1681+
1682+ case "set_quat" :
1683+ quat = data_kwargs .pop ("quat" )
1684+ if quat .requires_grad :
1685+ quat ._backward_from_ti (self .set_quat_grad , data_kwargs ["envs_idx" ], data_kwargs ["relative" ])
1686+
1687+ case "set_dofs_velocity" :
1688+ velocity = data_kwargs .pop ("velocity" )
1689+ # [velocity] could be None when we want to zero the velocity (see set_dofs_velocity of RigidSolver)
1690+ if velocity is not None and velocity .requires_grad :
1691+ velocity ._backward_from_ti (
1692+ self .set_dofs_velocity_grad ,
1693+ data_kwargs ["dofs_idx_local" ],
1694+ data_kwargs ["envs_idx" ],
1695+ )
1696+ case _:
1697+ gs .raise_exception (f"Invalid target key: { key } not in { self ._tgt_keys } " )
1698+
1699+ def save_ckpt (self , ckpt_name ):
1700+ if ckpt_name not in self ._ckpt :
1701+ self ._ckpt [ckpt_name ] = {}
1702+ self ._ckpt [ckpt_name ]["_tgt_buffer" ] = self ._tgt_buffer .copy ()
1703+ self ._tgt_buffer .clear ()
1704+
1705+ def load_ckpt (self , ckpt_name ):
1706+ self ._tgt_buffer = self ._ckpt [ckpt_name ]["_tgt_buffer" ].copy ()
1707+
1708+ def reset_grad (self ):
1709+ self ._tgt_buffer .clear ()
1710+
1711+ @gs .assert_built
1712+ def get_state (self ):
1713+ state = RigidEntityState (self , self ._sim .cur_step_global )
1714+
1715+ solver_state = self ._solver .get_state ()
1716+ pos = solver_state .links_pos [:, self .base_link_idx ]
1717+ quat = solver_state .links_quat [:, self .base_link_idx ]
1718+
1719+ state ._pos = pos
1720+ state ._quat = quat
1721+
1722+ return state
16011723
16021724 def _get_global_idx (self , idx_local , idx_local_max , idx_global_start = 0 , * , unsafe = False ):
16031725 # Handling default argument and special cases
@@ -1967,6 +2089,7 @@ def get_links_invweight(self, links_idx_local=None, envs_idx=None):
19672089 return self ._solver .get_links_invweight (links_idx , envs_idx )
19682090
19692091 @gs .assert_built
2092+ @tracked
19702093 def set_pos (self , pos , envs_idx = None , * , zero_velocity = True , relative = False ):
19712094 """
19722095 Set position of the entity's base link.
@@ -1989,6 +2112,16 @@ def set_pos(self, pos, envs_idx=None, *, zero_velocity=True, relative=False):
19892112 self ._solver .set_base_links_pos (pos , self ._base_links_idx_ , envs_idx , relative = relative )
19902113
19912114 @gs .assert_built
2115+ def set_pos_grad (self , envs_idx , relative , pos_grad ):
2116+ self ._solver .set_base_links_pos_grad (
2117+ self ._base_links_idx_ ,
2118+ envs_idx ,
2119+ relative ,
2120+ pos_grad .data ,
2121+ )
2122+
2123+ @gs .assert_built
2124+ @tracked
19922125 def set_quat (self , quat , envs_idx = None , * , zero_velocity = True , relative = False ):
19932126 """
19942127 Set quaternion of the entity's base link.
@@ -2010,6 +2143,15 @@ def set_quat(self, quat, envs_idx=None, *, zero_velocity=True, relative=False):
20102143 self ._solver .set_dofs_velocity (None , self ._dofs_idx , envs_idx , skip_forward = True )
20112144 self ._solver .set_base_links_quat (quat , self ._base_links_idx_ , envs_idx , relative = relative )
20122145
2146+ @gs .assert_built
2147+ def set_quat_grad (self , envs_idx , relative , quat_grad ):
2148+ self ._solver .set_base_links_quat_grad (
2149+ self ._base_links_idx_ ,
2150+ envs_idx ,
2151+ relative ,
2152+ quat_grad .data ,
2153+ )
2154+
20132155 @gs .assert_built
20142156 def get_verts (self ):
20152157 """
@@ -2169,6 +2311,7 @@ def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None
21692311 self ._solver .set_dofs_frictionloss (frictionloss , dofs_idx , envs_idx )
21702312
21712313 @gs .assert_built
2314+ @tracked
21722315 def set_dofs_velocity (self , velocity = None , dofs_idx_local = None , envs_idx = None , * , skip_forward = False ):
21732316 """
21742317 Set the entity's dofs' velocity.
@@ -2185,6 +2328,11 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *
21852328 dofs_idx = self ._get_global_idx (dofs_idx_local , self .n_dofs , self ._dof_start , unsafe = True )
21862329 self ._solver .set_dofs_velocity (velocity , dofs_idx , envs_idx , skip_forward = skip_forward )
21872330
2331+ @gs .assert_built
2332+ def set_dofs_velocity_grad (self , dofs_idx_local , envs_idx , velocity_grad ):
2333+ dofs_idx = self ._get_idx (dofs_idx_local , self .n_dofs , self ._dof_start , unsafe = True )
2334+ self ._solver .set_dofs_velocity_grad (dofs_idx , envs_idx , velocity_grad .data )
2335+
21882336 @gs .assert_built
21892337 def set_dofs_position (self , position , dofs_idx_local = None , envs_idx = None , * , zero_velocity = True ):
21902338 """
0 commit comments