1+ from functools import wraps
2+
13import igl
24import numpy as np
35import gstaichi as ti
1618from .base_entity import Entity
1719
1820
21+ def assert_muscle (method ):
22+ @wraps (method )
23+ def wrapper (self , * args , ** kwargs ):
24+ if not isinstance (self .material , gs .materials .FEM .Muscle ):
25+ gs .raise_exception ("This method is only supported by entities with 'FEM.Muscle' material." )
26+ return method (self , * args , ** kwargs )
27+
28+ return wrapper
29+
30+
1931@ti .data_oriented
2032class FEMEntity (Entity ):
2133 """
@@ -200,6 +212,7 @@ def set_velocity(self, vel):
200212 if not is_valid :
201213 gs .raise_exception ("Tensor shape not supported." )
202214
215+ @assert_muscle
203216 def set_actuation (self , actu ):
204217 """
205218 Set the actuation signal for the FEM entity.
@@ -221,9 +234,8 @@ def set_actuation(self, actu):
221234
222235 actu = to_gs_tensor (actu )
223236
224- n_groups = getattr (self .material , "n_groups" , 1 )
225-
226237 is_valid = False
238+ n_groups = self .material .n_groups
227239 if actu .ndim == 0 :
228240 self ._tgt ["actu" ] = actu .tile ((self ._sim ._B , n_groups ))
229241 is_valid = True
@@ -257,19 +269,17 @@ def set_muscle(self, muscle_group=None, muscle_direction=None):
257269 AssertionError
258270 If tensor shapes are incorrect or normalization fails.
259271 """
260-
261272 self ._assert_active ()
262273
263- if muscle_group is not None :
264- n_groups = getattr (self .material , "n_groups" , 1 )
265- max_group_id = muscle_group .max ().item ()
274+ n_groups = self .material .n_groups
275+ max_group_id = muscle_group .max ().item ()
266276
267- muscle_group = to_gs_tensor (muscle_group )
277+ muscle_group = to_gs_tensor (muscle_group )
268278
269- assert muscle_group .shape == (self .n_elements ,)
270- assert isinstance (max_group_id , int ) and max_group_id < n_groups
279+ assert muscle_group .shape == (self .n_elements ,)
280+ assert isinstance (max_group_id , int ) and max_group_id < n_groups
271281
272- self .set_muscle_group (muscle_group )
282+ self .set_muscle_group (muscle_group )
273283
274284 if muscle_direction is not None :
275285 muscle_direction = to_gs_tensor (muscle_direction )
@@ -280,12 +290,7 @@ def set_muscle(self, muscle_group=None, muscle_direction=None):
280290
281291 def get_state (self ):
282292 state = FEMEntityState (self , self ._sim .cur_step_global )
283- self .get_frame (
284- self ._sim .cur_substep_local ,
285- state .pos ,
286- state .vel ,
287- state .active ,
288- )
293+ self .get_frame (self ._sim .cur_substep_local , state .pos , state .vel , state .active )
289294
290295 # we store all queried states to track gradient flow
291296 self ._queried_states .append (state )
@@ -775,6 +780,7 @@ def set_active(self, f, active):
775780 active = active ,
776781 )
777782
783+ @assert_muscle
778784 def set_muscle_group (self , muscle_group ):
779785 """
780786 Set muscle group index for each element.
@@ -791,6 +797,7 @@ def set_muscle_group(self, muscle_group):
791797 muscle_group = muscle_group ,
792798 )
793799
800+ @assert_muscle
794801 def set_muscle_direction (self , muscle_direction ):
795802 """
796803 Set muscle force direction for each element.
0 commit comments