@@ -1197,22 +1197,10 @@ def get_state(self, f):
11971197 if self .is_active :
11981198 state = RigidSolverState (self ._scene )
11991199
1200- # qpos: ti.types.ndarray(),
1201- # vel: ti.types.ndarray(),
1202- # links_pos: ti.types.ndarray(),
1203- # links_quat: ti.types.ndarray(),
1204- # i_pos_shift: ti.types.ndarray(),
1205- # mass_shift: ti.types.ndarray(),
1206- # friction_ratio: ti.types.ndarray(),
1207- # links_state: array_class.LinksState,
1208- # dofs_state: array_class.DofsState,
1209- # geoms_state: array_class.GeomsState,
1210- # rigid_global_info: array_class.RigidGlobalInfo,
1211- # static_rigid_sim_config: ti.template(),
1212-
12131200 kernel_get_state (
12141201 qpos = state .qpos ,
12151202 vel = state .dofs_vel ,
1203+ acc = state .dofs_acc ,
12161204 links_pos = state .links_pos ,
12171205 links_quat = state .links_quat ,
12181206 i_pos_shift = state .i_pos_shift ,
@@ -1234,6 +1222,7 @@ def set_state(self, f, state, envs_idx=None):
12341222 kernel_set_state (
12351223 qpos = state .qpos ,
12361224 dofs_vel = state .dofs_vel ,
1225+ dofs_acc = state .dofs_acc ,
12371226 links_pos = state .links_pos ,
12381227 links_quat = state .links_quat ,
12391228 i_pos_shift = state .i_pos_shift ,
@@ -1749,6 +1738,14 @@ def set_sol_params(self, sol_params, geoms_idx=None, envs_idx=None, *, joints_id
17491738 )
17501739
17511740 def _set_dofs_info (self , tensor_list , dofs_idx , name , envs_idx = None , * , unsafe = False ):
1741+ if gs .use_zerocopy and name in {"kp" , "kv" , "force_range" , "stiffness" , "damping" , "frictionloss" , "limit" }:
1742+ mask = indices_to_mask (* ((envs_idx , dofs_idx ) if self ._options .batch_dofs_info else (dofs_idx ,)))
1743+ data = ti_to_torch (getattr (self .dofs_info , name ), transpose = True , copy = False )
1744+ num_values = len (tensor_list )
1745+ for j , mask_j in enumerate (((* mask , ..., j ) for j in range (num_values )) if num_values > 1 else (mask ,)):
1746+ data [mask_j ] = torch .as_tensor (tensor_list [j ], dtype = gs .tc_float , device = gs .device )
1747+ return
1748+
17521749 tensor_list = list (tensor_list )
17531750 for j , tensor in enumerate (tensor_list ):
17541751 tensor_list [j ], dofs_idx , envs_idx_ = self ._sanitize_1D_io_variables (
@@ -1775,7 +1772,7 @@ def _set_dofs_info(self, tensor_list, dofs_idx, name, envs_idx=None, *, unsafe=F
17751772 elif name == "armature" :
17761773 kernel_set_dofs_armature (tensor_list [0 ], dofs_idx , envs_idx_ , self .dofs_info , self ._static_rigid_sim_config )
17771774 qs_idx = torch .arange (self .n_qs , dtype = gs .tc_int , device = gs .device )
1778- qpos_cur = self .get_qpos (envs_idx = envs_idx , qs_idx = qs_idx , unsafe = unsafe )
1775+ qpos_cur = self .get_qpos (qs_idx = qs_idx , envs_idx = envs_idx , unsafe = unsafe )
17791776 self ._init_invweight_and_meaninertia (envs_idx = envs_idx , force_update = True , unsafe = unsafe )
17801777 self .set_qpos (qpos_cur , qs_idx = qs_idx , envs_idx = envs_idx , unsafe = unsafe )
17811778 elif name == "damping" :
@@ -5842,11 +5839,11 @@ def kernel_update_vgeoms_render_T(
58425839 vgeoms_render_T [(i_g , i_b , * J )] = ti .cast (geom_T [J ], ti .float32 )
58435840
58445841
5845- # FIXME: This kernel cannot use 'pure' because 'gs.Tensor' is currently not support by GsTaichi
5846- @ti .kernel (fastcache = False )
5842+ @ti .kernel (fastcache = gs .use_fastcache )
58475843def kernel_get_state (
58485844 qpos : ti .types .ndarray (),
58495845 vel : ti .types .ndarray (),
5846+ acc : ti .types .ndarray (),
58505847 links_pos : ti .types .ndarray (),
58515848 links_quat : ti .types .ndarray (),
58525849 i_pos_shift : ti .types .ndarray (),
@@ -5872,6 +5869,7 @@ def kernel_get_state(
58725869 ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
58735870 for i_d , i_b in ti .ndrange (n_dofs , _B ):
58745871 vel [i_b , i_d ] = dofs_state .vel [i_d , i_b ]
5872+ acc [i_b , i_d ] = dofs_state .acc [i_d , i_b ]
58755873
58765874 ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
58775875 for i_l , i_b in ti .ndrange (n_links , _B ):
@@ -5887,11 +5885,11 @@ def kernel_get_state(
58875885 friction_ratio [i_b , i_l ] = geoms_state .friction_ratio [i_l , i_b ]
58885886
58895887
5890- # FIXME: This kernel cannot use 'pure' because 'gs.Tensor' is currently not support by GsTaichi
5891- @ti .kernel (fastcache = False )
5888+ @ti .kernel (fastcache = gs .use_fastcache )
58925889def kernel_set_state (
58935890 qpos : ti .types .ndarray (),
58945891 dofs_vel : ti .types .ndarray (),
5892+ dofs_acc : ti .types .ndarray (),
58955893 links_pos : ti .types .ndarray (),
58965894 links_quat : ti .types .ndarray (),
58975895 i_pos_shift : ti .types .ndarray (),
@@ -5917,6 +5915,7 @@ def kernel_set_state(
59175915 ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
59185916 for i_d , i_b_ in ti .ndrange (n_dofs , envs_idx .shape [0 ]):
59195917 dofs_state .vel [i_d , envs_idx [i_b_ ]] = dofs_vel [envs_idx [i_b_ ], i_d ]
5918+ dofs_state .acc [i_d , envs_idx [i_b_ ]] = dofs_acc [envs_idx [i_b_ ], i_d ]
59205919 dofs_state .ctrl_force [i_d , envs_idx [i_b_ ]] = gs .ti_float (0.0 )
59215920 dofs_state .ctrl_mode [i_d , envs_idx [i_b_ ]] = gs .CTRL_MODE .FORCE
59225921
0 commit comments