1313from genesis .engine .couplers import SAPCoupler
1414from genesis .engine .states .cache import QueriedStates
1515from genesis .engine .states .entities import FEMEntityState
16- from genesis .utils .misc import ALLOCATE_TENSOR_WARNING , to_gs_tensor , tensor_to_array
16+ from genesis .utils .misc import to_gs_tensor , tensor_to_array , broadcast_tensor
1717
1818from .base_entity import Entity
1919
@@ -126,6 +126,34 @@ def __init__(self, scene, solver, material, morph, surface, idx, v_start=0, el_s
126126 # ----------------------------------- basic entity ops -------------------------------
127127 # ------------------------------------------------------------------------------------
128128
129+ def _sanitize_verts_idx_local (self , verts_idx_local = None , envs_idx = None ):
130+ if verts_idx_local is None :
131+ verts_idx_local = range (self .n_vertices )
132+
133+ if envs_idx is None :
134+ verts_idx_local_ = broadcast_tensor (verts_idx_local , gs .tc_int , (- 1 ,), ("envs_idx" ,))
135+ else :
136+ verts_idx_local_ = broadcast_tensor (verts_idx_local , gs .tc_int , (len (envs_idx ), - 1 ), ("envs_idx" , "" ))
137+
138+ # FIXME: This check is too expensive
139+ # if not (0 <= verts_idx_local_ & verts_idx_local_ < self.n_vertices).all():
140+ # gs.raise_exception("Elements of `verts_idx_local' are out-of-range.")
141+
142+ return verts_idx_local_ .contiguous ()
143+
144+ def _sanitize_verts_tensor (self , tensor , dtype , verts_idx = None , envs_idx = None , element_shape = (), * , batched = True ):
145+ n_vertices = verts_idx .shape [- 1 ] if verts_idx is not None else self .n_vertices
146+ if batched :
147+ assert envs_idx is not None
148+ batch_shape = (len (envs_idx ), n_vertices )
149+ dim_names = ("envs_idx" , "verts_idx" , * ("" for _ in element_shape ))
150+ else :
151+ batch_shape = (n_vertices ,)
152+ dim_names = ("verts_idx" , * ("" for _ in element_shape ))
153+ tensor_shape = (* batch_shape , * element_shape )
154+
155+ return broadcast_tensor (tensor , dtype , tensor_shape , dim_names ).contiguous ()
156+
129157 def set_position (self , pos ):
130158 """
131159 Set the target position(s) for the FEM entity.
@@ -153,14 +181,14 @@ def set_position(self, pos):
153181 if pos .ndim == 1 :
154182 if pos .shape == (3 ,):
155183 pos = self .init_positions_COM_offset + pos
156- self ._tgt ["pos" ] = pos . unsqueeze ( 0 ) .tile ((self ._sim ._B , 1 , 1 ))
184+ self ._tgt ["pos" ] = pos [ None ] .tile ((self ._sim ._B , 1 , 1 ))
157185 is_valid = True
158186 elif pos .ndim == 2 :
159187 if pos .shape == (self .n_vertices , 3 ):
160- self ._tgt ["pos" ] = pos . unsqueeze ( 0 ) .tile ((self ._sim ._B , 1 , 1 ))
188+ self ._tgt ["pos" ] = pos [ None ] .tile ((self ._sim ._B , 1 , 1 ))
161189 is_valid = True
162190 elif pos .shape == (self ._sim ._B , 3 ):
163- pos = self .init_positions_COM_offset . unsqueeze ( 0 ) + pos . unsqueeze ( 1 )
191+ pos = self .init_positions_COM_offset [ None ] + pos [:, None ]
164192 self ._tgt ["pos" ] = pos
165193 is_valid = True
166194 elif pos .ndim == 3 :
@@ -200,10 +228,10 @@ def set_velocity(self, vel):
200228 is_valid = True
201229 elif vel .ndim == 2 :
202230 if vel .shape == (self .n_vertices , 3 ):
203- self ._tgt ["vel" ] = vel . unsqueeze ( 0 ) .tile ((self ._sim ._B , 1 , 1 ))
231+ self ._tgt ["vel" ] = vel [ None ] .tile ((self ._sim ._B , 1 , 1 ))
204232 is_valid = True
205233 elif vel .shape == (self ._sim ._B , 3 ):
206- self ._tgt ["vel" ] = vel . unsqueeze ( 1 ) .tile ((1 , self .n_vertices , 1 ))
234+ self ._tgt ["vel" ] = vel [:, None ] .tile ((1 , self .n_vertices , 1 ))
207235 is_valid = True
208236 elif vel .ndim == 3 :
209237 if vel .shape == (self ._sim ._B , self .n_vertices , 3 ):
@@ -241,7 +269,7 @@ def set_actuation(self, actu):
241269 is_valid = True
242270 elif actu .ndim == 1 :
243271 if actu .shape == (n_groups ,):
244- self ._tgt ["actu" ] = actu . unsqueeze ( 0 ) .tile ((self ._sim ._B , 1 ))
272+ self ._tgt ["actu" ] = actu [ None ] .tile ((self ._sim ._B , 1 ))
245273 is_valid = True
246274 elif actu .shape == (self .n_elements ,):
247275 gs .raise_exception ("Cannot set per-element actuation." )
@@ -814,46 +842,16 @@ def set_muscle_direction(self, muscle_direction):
814842 muscle_direction = muscle_direction ,
815843 )
816844
817- def _sanitize_input_tensor (self , tensor , dtype , unbatched_ndim = 1 ):
818- _tensor = torch .as_tensor (tensor , dtype = dtype , device = gs .device )
819-
820- if _tensor .ndim < unbatched_ndim + 1 :
821- _tensor = _tensor .repeat ((self ._sim ._B , * ((1 ,) * max (1 , _tensor .ndim ))))
822- if self ._sim ._B > 1 :
823- gs .logger .debug (ALLOCATE_TENSOR_WARNING )
824- else :
825- _tensor = _tensor .contiguous ()
826- if _tensor is not tensor :
827- gs .logger .debug (ALLOCATE_TENSOR_WARNING )
828-
829- if len (_tensor ) != self ._sim ._B :
830- gs .raise_exception ("Input tensor batch size must match the number of environments." )
831-
832- if _tensor .ndim != unbatched_ndim + 1 :
833- gs .raise_exception (f"Input tensor ndim is { _tensor .ndim } , should be { unbatched_ndim + 1 } ." )
834-
835- return _tensor
836-
837- def _sanitize_input_verts_idx (self , verts_idx_local ):
838- verts_idx = self ._sanitize_input_tensor (verts_idx_local , dtype = gs .tc_int , unbatched_ndim = 1 ) + self ._v_start
839- assert ((verts_idx >= 0 ) & (verts_idx < self ._solver .n_vertices )).all (), "Vertex indices out of bounds."
840- return verts_idx
841-
842- def _sanitize_input_poss (self , poss ):
843- poss = self ._sanitize_input_tensor (poss , dtype = gs .tc_float , unbatched_ndim = 2 )
844- assert poss .ndim == 3 and poss .shape [2 ] == 3 , "Position tensor must have shape (B, num_verts, 3)."
845- return poss
846-
847845 def set_vertex_constraints (
848- self , verts_idx , target_poss = None , link = None , is_soft_constraint = False , stiffness = 0.0 , envs_idx = None
846+ self , verts_idx_local , target_poss = None , link = None , is_soft_constraint = False , stiffness = 0.0 , envs_idx = None
849847 ):
850848 """
851849 Set vertex constraints for specified vertices.
852850
853851 Parameters
854852 ----------
855- verts_idx : array_like
856- List of vertex indices to constrain.
853+ verts_idx_local : array_like
854+ List of local vertex indices to constrain.
857855 target_poss : array_like, shape (len(verts_idx), 3), optional
858856 List of target positions [x, y, z] for each vertex. If not provided, the initial positions are used.
859857 link : RigidLink
@@ -874,30 +872,22 @@ def set_vertex_constraints(
874872 if not self ._solver ._constraints_initialized :
875873 self ._solver .init_constraints ()
876874
875+ use_current_poss = target_poss is None
877876 envs_idx = self ._scene ._sanitize_envs_idx (envs_idx )
878- verts_idx = self ._sanitize_input_verts_idx (verts_idx )
877+ verts_idx_local = self ._sanitize_verts_idx_local (verts_idx_local , envs_idx )
878+ verts_idx = verts_idx_local + self ._v_start
879+ target_poss = self ._sanitize_verts_tensor (target_poss , gs .tc_float , verts_idx , envs_idx , (3 ,))
879880
880- if target_poss is None :
881- target_poss = torch .zeros (
882- (verts_idx .shape [0 ], verts_idx .shape [1 ], 3 ), dtype = gs .tc_float , device = gs .device , requires_grad = False
883- )
881+ if use_current_poss :
884882 self ._kernel_get_verts_pos (self ._sim .cur_substep_local , target_poss , verts_idx )
885- target_poss = self ._sanitize_input_poss (target_poss )
886-
887- assert (
888- len (envs_idx ) == len (target_poss ) == len (verts_idx )
889- ), "First dimension should match number of environments."
890- assert target_poss .shape [1 ] == verts_idx .shape [1 ], "Target position should be provided for each vertex."
891883
892884 if link is None :
893885 link_init_pos = torch .zeros ((self ._sim ._B , 3 ), dtype = gs .tc_float , device = gs .device )
894886 link_init_quat = torch .zeros ((self ._sim ._B , 4 ), dtype = gs .tc_float , device = gs .device )
895887 link_idx = - 1
896888 else :
897889 assert isinstance (link , RigidLink ), "Only RigidLink is supported for vertex constraints."
898- link_init_pos = self ._sanitize_input_tensor (link .get_pos (), dtype = gs .tc_float )
899- link_init_quat = self ._sanitize_input_tensor (link .get_quat (), dtype = gs .tc_float )
900- link_idx = link .idx
890+ link_init_pos , link_init_quat , link_idx = link .get_pos (), link .get_quat (), link .idx
901891
902892 self ._solver ._kernel_set_vertex_constraints (
903893 self ._sim .cur_substep_local ,
@@ -911,31 +901,36 @@ def set_vertex_constraints(
911901 envs_idx ,
912902 )
913903
914- def update_constraint_targets (self , verts_idx , target_poss , envs_idx = None ):
904+ def update_constraint_targets (self , verts_idx_local , target_poss , envs_idx = None ):
915905 """Update target positions for existing constraints."""
916906 if not self ._solver ._constraints_initialized :
917907 gs .logger .warning ("Ignoring update_constraint_targets; constraints have not been initialized." )
918908 return
919909
910+ assert target_poss is not None
920911 envs_idx = self ._scene ._sanitize_envs_idx (envs_idx )
921- verts_idx = self ._sanitize_input_verts_idx ( verts_idx )
922- target_poss = self ._sanitize_input_poss ( target_poss )
923- assert target_poss . shape [ 1 ] == verts_idx . shape [ 1 ], "Target position should be provided for each vertex."
912+ verts_idx_local = self ._sanitize_verts_idx_local ( verts_idx_local , envs_idx )
913+ verts_idx = verts_idx_local + self ._v_start
914+ target_poss = self . _sanitize_verts_tensor ( target_poss , gs . tc_float , verts_idx , envs_idx , ( 3 ,))
924915
925916 self ._solver ._kernel_update_constraint_targets (verts_idx , target_poss , envs_idx )
926917
927- def remove_vertex_constraints (self , verts_idx = None , envs_idx = None ):
918+ def remove_vertex_constraints (self , verts_idx_local = None , envs_idx = None ):
928919 """Remove constraints from specified vertices, or all if None."""
929920 if not self ._solver ._constraints_initialized :
930921 gs .logger .warning ("Ignoring remove_vertex_constraints; constraints have not been initialized." )
931922 return
932923
933- if verts_idx is None :
924+ # FIXME: GsTaichi 'fill' method is very inefficient. Try using zero-copy if possible.
925+ if verts_idx_local is None :
934926 self ._solver .vertex_constraints .is_constrained .fill (0 )
935- else :
936- verts_idx = self ._sanitize_input_verts_idx (verts_idx )
937- envs_idx = self ._scene ._sanitize_envs_idx (envs_idx )
938- self ._solver ._kernel_remove_specific_constraints (verts_idx , envs_idx )
927+ return
928+
929+ envs_idx = self ._scene ._sanitize_envs_idx (envs_idx )
930+ verts_idx_local = self ._sanitize_verts_idx_local (verts_idx_local , envs_idx )
931+ verts_idx = verts_idx_local + self ._v_start
932+
933+ self ._solver ._kernel_remove_specific_constraints (verts_idx , envs_idx )
939934
940935 @ti .kernel
941936 def _kernel_get_verts_pos (self , f : ti .i32 , pos : ti .types .ndarray (), verts_idx : ti .types .ndarray ()):
@@ -954,14 +949,8 @@ def get_el2v(self):
954949 el2v : gs.Tensor
955950 Tensor of shape (n_elements, 4) mapping each element to its vertex indices.
956951 """
957-
958952 el2v = gs .zeros ((self .n_elements , 4 ), dtype = int , requires_grad = False , scene = self .scene )
959- self ._solver ._kernel_get_el2v (
960- element_el_start = self ._el_start ,
961- n_elements = self .n_elements ,
962- el2v = el2v ,
963- )
964-
953+ self ._solver ._kernel_get_el2v (element_el_start = self ._el_start , n_elements = self .n_elements , el2v = el2v )
965954 return el2v
966955
967956 @ti .kernel
0 commit comments