33
44import genesis as gs
55import genesis .utils .geom as gu
6+ from genesis .utils import array_class
67from genesis .utils .misc import DeprecationError
78from genesis .repr_base import RBC
89
910
10- @ti .data_oriented
1111class RigidJoint (RBC ):
1212 """
1313 Joint class for rigid body entities. Each RigidLink is connected to its parent link via a RigidJoint.
@@ -102,18 +102,11 @@ def get_anchor_pos(self):
102102 the anchor point is the "output" of the joint transmission, on which the child body is welded.
103103 """
104104 tensor = torch .empty ((self ._solver ._B , 3 ), dtype = gs .tc_float , device = gs .device )
105- self . _kernel_get_anchor_pos (tensor )
105+ _kernel_get_anchor_pos (self . _idx , tensor , self . _solver . joints_state )
106106 if self ._solver .n_envs == 0 :
107107 tensor = tensor .squeeze (0 )
108108 return tensor
109109
110- @ti .kernel
111- def _kernel_get_anchor_pos (self , tensor : ti .types .ndarray ()):
112- for i_b in range (self ._solver ._B ):
113- xpos = self ._solver .joints_state .xanchor [self ._idx , i_b ]
114- for i in ti .static (range (3 )):
115- tensor [i_b , i ] = xpos [i ]
116-
117110 @gs .assert_built
118111 def get_anchor_axis (self ):
119112 """
@@ -122,18 +115,11 @@ def get_anchor_axis(self):
122115 See `RigidJoint.get_anchor_pos` documentation for details about the notion on anchor point.
123116 """
124117 tensor = torch .empty ((self ._solver ._B , 3 ), dtype = gs .tc_float , device = gs .device )
125- self . _kernel_get_anchor_axis (tensor )
118+ _kernel_get_anchor_axis (self . _idx , tensor , self . _solver . joints_state )
126119 if self ._solver .n_envs == 0 :
127120 tensor = tensor .squeeze (0 )
128121 return tensor
129122
130- @ti .kernel
131- def _kernel_get_anchor_axis (self , tensor : ti .types .ndarray ()):
132- for i_b in range (self ._solver ._B ):
133- xaxis = self ._solver .joints_state .xaxis [self ._idx , i_b ]
134- for i in ti .static (range (3 )):
135- tensor [i_b , i ] = xaxis [i ]
136-
137123 def set_sol_params (self , sol_params ):
138124 """
139125 Set the solver parameters of this joint.
@@ -454,3 +440,21 @@ def is_built(self):
454440
455441 def _repr_brief (self ):
456442 return f"{ (self ._repr_type ())} : { self ._uid } , name: '{ self ._name } ', idx: { self ._idx } , type: { self ._type } "
443+
444+
445+ @ti .kernel
446+ def _kernel_get_anchor_pos (joint_idx : ti .i32 , tensor : ti .types .ndarray (), joints_state : array_class .JointsState ):
447+ _B = joints_state .xanchor .shape [1 ]
448+ for i_b in range (_B ):
449+ xpos = joints_state .xanchor [joint_idx , i_b ]
450+ for i in ti .static (range (3 )):
451+ tensor [i_b , i ] = xpos [i ]
452+
453+
454+ @ti .kernel
455+ def _kernel_get_anchor_axis (joint_idx : ti .i32 , tensor : ti .types .ndarray (), joints_state : array_class .JointsState ):
456+ _B = joints_state .xaxis .shape [1 ]
457+ for i_b in range (_B ):
458+ xaxis = joints_state .xaxis [joint_idx , i_b ]
459+ for i in ti .static (range (3 )):
460+ tensor [i_b , i ] = xaxis [i ]
0 commit comments