1414from genesis .engine .states .solvers import RigidSolverState
1515from genesis .options .solvers import RigidOptions
1616from genesis .utils import linalg as lu
17- from genesis .utils .misc import (
18- ALLOCATE_TENSOR_WARNING ,
19- DeprecationError ,
20- ti_to_torch ,
21- ti_to_numpy ,
22- ti_to_python ,
23- indices_to_mask ,
24- _get_ti_metadata ,
25- )
17+ from genesis .utils .misc import ALLOCATE_TENSOR_WARNING , DeprecationError , ti_to_torch , ti_to_numpy , indices_to_mask
2618from genesis .utils .sdf_decomp import SDF
2719
2820from ..base_solver import Solver
@@ -131,7 +123,6 @@ def __init__(self, scene: "Scene", sim: "Simulator", options: RigidOptions) -> N
131123 self ._options = options
132124
133125 self ._cur_step = - 1
134- self ._links_state_cache = {}
135126
136127 self .qpos : ti .Template | ti .types .NDArray | None = None
137128
@@ -819,38 +810,10 @@ def _init_constraint_solver(self):
819810 else :
820811 self .constraint_solver = ConstraintSolver (self )
821812
822- def _get_links_data (
823- self ,
824- field_name : str ,
825- row_mask : slice | int | range | list | torch .Tensor | np .ndarray | None = None ,
826- col_mask : slice | int | range | list | torch .Tensor | np .ndarray | None = None ,
827- keepdim = True ,
828- * ,
829- to_torch = True ,
830- ):
831- links_state_py = self ._links_state_cache .setdefault ((to_torch ,), {})
832-
833- field = getattr (self .links_state , field_name )
834- tensor = links_state_py .get (field_name )
835- if tensor is None :
836- tensor = links_state_py [field_name ] = ti_to_python (field , transpose = True , to_torch = to_torch )
837-
838- ti_data_meta = _get_ti_metadata (field )
839- if len (ti_data_meta .shape ) < 2 :
840- if row_mask is not None and col_mask is not None :
841- gs .raise_exception ("Cannot specify both row and colum masks for tensor with 1D batch." )
842- mask = indices_to_mask (row_mask if col_mask is None else col_mask , keepdim = keepdim , to_torch = to_torch )
843- else :
844- mask = indices_to_mask (row_mask , col_mask , keepdim = keepdim , to_torch = to_torch )
845-
846- return tensor [mask ]
847-
848813 def substep (self ):
849814 # from genesis.utils.tools import create_timer
850815 from genesis .engine .couplers import SAPCoupler
851816
852- self ._links_state_cache .clear ()
853-
854817 kernel_step_1 (
855818 links_state = self .links_state ,
856819 links_info = self .links_info ,
@@ -1301,7 +1264,6 @@ def set_state(self, f, state, envs_idx=None):
13011264 self .collider .clear (envs_idx )
13021265 if self .constraint_solver is not None :
13031266 self .constraint_solver .reset (envs_idx )
1304- self ._links_state_cache .clear ()
13051267 self ._cur_step = - 1
13061268
13071269 def process_input (self , in_backward = False ):
@@ -1541,7 +1503,6 @@ def set_base_links_pos(self, pos, links_idx=None, envs_idx=None, *, relative=Fal
15411503 static_rigid_sim_config = self ._static_rigid_sim_config ,
15421504 )
15431505
1544- self ._links_state_cache .clear ()
15451506 kernel_forward_kinematics_links_geoms (
15461507 envs_idx ,
15471508 links_state = self .links_state ,
@@ -1590,7 +1551,6 @@ def set_base_links_quat(self, quat, links_idx=None, envs_idx=None, *, relative=F
15901551 static_rigid_sim_config = self ._static_rigid_sim_config ,
15911552 )
15921553
1593- self ._links_state_cache .clear ()
15941554 kernel_forward_kinematics_links_geoms (
15951555 envs_idx ,
15961556 links_state = self .links_state ,
@@ -1668,7 +1628,6 @@ def set_qpos(self, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False, unsa
16681628 qpos = qpos .unsqueeze (0 )
16691629 kernel_set_qpos (qpos , qs_idx , envs_idx , self ._rigid_global_info , self ._static_rigid_sim_config )
16701630
1671- self ._links_state_cache .clear ()
16721631 self .collider .reset (envs_idx , cache_only = True )
16731632 if not isinstance (envs_idx , torch .Tensor ):
16741633 envs_idx = self ._scene ._sanitize_envs_idx (envs_idx , unsafe = unsafe )
@@ -1877,7 +1836,6 @@ def set_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None, *, skip_forw
18771836 velocity = velocity .unsqueeze (0 )
18781837 kernel_set_dofs_velocity (velocity , dofs_idx , envs_idx , self .dofs_state , self ._static_rigid_sim_config )
18791838
1880- self ._links_state_cache .clear ()
18811839 if not skip_forward :
18821840 kernel_forward_velocity (
18831841 envs_idx ,
@@ -1908,7 +1866,6 @@ def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, unsafe=Fa
19081866 self ._static_rigid_sim_config ,
19091867 )
19101868
1911- self ._links_state_cache .clear ()
19121869 self .collider .reset (envs_idx , cache_only = True )
19131870 self .collider .clear (envs_idx )
19141871 if self .constraint_solver is not None :
@@ -2055,20 +2012,20 @@ def get_links_pos(
20552012 ):
20562013 ref = self ._convert_ref_to_idx (ref )
20572014 if ref == 0 :
2058- tensor = self ._get_links_data ( " root_COM" , envs_idx , links_idx , to_torch = to_torch )
2015+ tensor = ti_to_torch ( self .links_state . root_COM , envs_idx , links_idx , transpose = True )
20592016 elif ref == 1 :
2060- i_pos = self ._get_links_data ( " i_pos" , envs_idx , links_idx , to_torch = to_torch )
2061- root_COM = self ._get_links_data ( " root_COM" , envs_idx , links_idx , to_torch = to_torch )
2017+ i_pos = ti_to_torch ( self .links_state . i_pos , envs_idx , links_idx , transpose = True )
2018+ root_COM = ti_to_torch ( self .links_state . root_COM , envs_idx , links_idx , transpose = True )
20622019 tensor = i_pos + root_COM
20632020 elif ref == 2 :
2064- tensor = self ._get_links_data ( " pos" , envs_idx , links_idx , to_torch = to_torch )
2021+ tensor = ti_to_torch ( self .links_state . pos , envs_idx , links_idx , transpose = True )
20652022 else :
20662023 gs .raise_exception ("'ref' must be either 'link_origin', 'link_com', or 'root_com'." )
20672024
20682025 return tensor [0 ] if self .n_envs == 0 else tensor
20692026
20702027 def get_links_quat (self , links_idx = None , envs_idx = None , * , to_torch = True , unsafe = False ):
2071- tensor = self ._get_links_data ( " quat" , envs_idx , links_idx , to_torch = to_torch )
2028+ tensor = ti_to_torch ( self .links_state . quat , envs_idx , links_idx , transpose = True )
20722029 return tensor [0 ] if self .n_envs == 0 else tensor
20732030
20742031 def get_links_vel (
@@ -2103,7 +2060,7 @@ def get_links_vel(
21032060 return _tensor
21042061
21052062 def get_links_ang (self , links_idx = None , envs_idx = None , * , to_torch = True , unsafe = False ):
2106- tensor = self ._get_links_data ( " cd_ang" , envs_idx , links_idx , to_torch = to_torch )
2063+ tensor = ti_to_torch ( self .links_state . cd_ang , envs_idx , links_idx , transpose = True )
21072064 return tensor [0 ] if self .n_envs == 0 else tensor
21082065
21092066 def get_links_acc (self , links_idx = None , envs_idx = None , * , unsafe = False ):
@@ -2121,7 +2078,7 @@ def get_links_acc(self, links_idx=None, envs_idx=None, *, unsafe=False):
21212078 return _tensor
21222079
21232080 def get_links_acc_ang (self , links_idx = None , envs_idx = None , * , to_torch = True , unsafe = False ):
2124- tensor = self ._get_links_data ( " cacc_ang" , envs_idx , links_idx , to_torch = to_torch )
2081+ tensor = ti_to_torch ( self .links_state . cacc_ang , envs_idx , links_idx , transpose = True )
21252082 return tensor [0 ] if self .n_envs == 0 else tensor
21262083
21272084 def get_links_root_COM (self , links_idx = None , envs_idx = None , * , to_torch = True , unsafe = False ):
@@ -2131,15 +2088,15 @@ def get_links_root_COM(self, links_idx=None, envs_idx=None, *, to_torch=True, un
21312088 This corresponds to the global COM of each entity, assuming a single-rooted structure - that is, as long as no
21322089 two successive links are connected by a free-floating joint (ie a joint that allows all 6 degrees of freedom).
21332090 """
2134- tensor = self ._get_links_data ( " root_COM" , envs_idx , links_idx , to_torch = to_torch )
2091+ tensor = ti_to_torch ( self .links_state . root_COM , envs_idx , links_idx , transpose = True )
21352092 return tensor [0 ] if self .n_envs == 0 else tensor
21362093
21372094 def get_links_mass_shift (self , links_idx = None , envs_idx = None , * , to_torch = True , unsafe = False ):
2138- tensor = self ._get_links_data ( " mass_shift" , envs_idx , links_idx , to_torch = to_torch )
2095+ tensor = ti_to_torch ( self .links_state . mass_shift , envs_idx , links_idx , transpose = True )
21392096 return tensor [0 ] if self .n_envs == 0 else tensor
21402097
21412098 def get_links_COM_shift (self , links_idx = None , envs_idx = None , * , to_torch = True , unsafe = False ):
2142- tensor = self ._get_links_data ( " i_pos_shift" , envs_idx , links_idx , to_torch = to_torch )
2099+ tensor = ti_to_torch ( self .links_state . i_pos_shift , envs_idx , links_idx , transpose = True )
21432100 return tensor [0 ] if self .n_envs == 0 else tensor
21442101
21452102 def get_links_inertial_mass (self , links_idx = None , envs_idx = None , * , unsafe = False ):
0 commit comments