@@ -38,6 +38,16 @@ def ensure_2d_tensor(th, dtype, device):
38
38
return th , N
39
39
40
40
41
+ def get_dict_elem_shape (th_dict ):
42
+ elem = th_dict [list (th_dict .keys ())[0 ]]
43
+ if isinstance (elem , np .ndarray ):
44
+ return elem .shape
45
+ elif isinstance (elem , torch .Tensor ):
46
+ return elem .shape
47
+ else :
48
+ return ()
49
+
50
+
41
51
class Chain :
42
52
"""
43
53
Robot model that may be constructed from different descriptions via their respective parsers.
@@ -174,6 +184,7 @@ def get_joints(self, exclude_fixed=True):
174
184
joints = self ._get_joints (self ._root , exclude_fixed = exclude_fixed )
175
185
return joints
176
186
187
+ @lru_cache ()
177
188
def get_joint_parameter_names (self , exclude_fixed = True ):
178
189
names = []
179
190
for f in self .get_joints (exclude_fixed = exclude_fixed ):
@@ -270,11 +281,16 @@ def get_frame_indices(self, *frame_names):
270
281
271
282
def forward_kinematics (self , th , frame_indices : Optional = None ):
272
283
"""
273
- Instead of a tree, we can use a flat data structure with indexes to represent the parent
274
- then instead of recursion we can just iterate in order and use parent pointers. This
275
- reduces function call overhead and moves some of the indexing work to the constructor.
284
+ Compute forward kinematics.
285
+
286
+ Args:
287
+ th: A dict, list, numpy array, or torch tensor of ALL joints values. Possibly batched.
288
+ the fastest thing to use is a torch tensor, all other types get converted to that.
289
+ If any joint values are missing, an exception will be thrown.
290
+ frame_indices: A list of frame indices to compute forward kinematics for. If None, all frames are computed.
276
291
277
- use `get_frame_indices` to get the indices for the frames you want to compute the pose for.
292
+ Returns:
293
+ A dict of frame names and their corresponding Transform3d objects.
278
294
"""
279
295
if frame_indices is None :
280
296
frame_indices = self .get_all_frame_indices ()
@@ -343,18 +359,28 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
343
359
return frame_names_and_transform3ds
344
360
345
361
def ensure_tensor (self , th ):
362
+ """
363
+ Converts a number of possible types into a tensor. The order of the tensor is determined by the order
364
+ of self.get_joint_parameter_names().
365
+ """
346
366
if isinstance (th , np .ndarray ):
347
367
th = torch .tensor (th , device = self .device , dtype = self .dtype )
348
- if isinstance (th , list ):
368
+ elif isinstance (th , list ):
349
369
th = torch .tensor (th , device = self .device , dtype = self .dtype )
350
- if isinstance (th , dict ):
370
+ elif isinstance (th , dict ):
351
371
# convert dict to a flat, complete, tensor of all joints values. Missing joints are filled with zeros.
352
372
th_dict = th
353
- th = torch .zeros (self .n_joints , device = self .device , dtype = self .dtype )
373
+ elem_shape = get_dict_elem_shape (th_dict )
374
+ th = torch .ones ([* elem_shape , self .n_joints ], device = self .device , dtype = self .dtype ) * torch .nan
354
375
joint_names = self .get_joint_parameter_names ()
355
376
for joint_name , joint_position in th_dict .items ():
356
377
jnt_idx = joint_names .index (joint_name )
357
- th [jnt_idx ] = joint_position
378
+ th [..., jnt_idx ] = joint_position
379
+ if torch .any (torch .isnan (th )):
380
+ msg = "Missing values for the following joints:\n "
381
+ for joint_name , th_i in zip (self .get_joint_parameter_names (), th ):
382
+ msg += joint_name + "\n "
383
+ raise ValueError (msg )
358
384
return th
359
385
360
386
def get_all_frame_indices (self ):
@@ -454,6 +480,7 @@ def jacobian(self, th, locations=None):
454
480
return jacobian .calc_jacobian (self , th , tool = locations )
455
481
456
482
def forward_kinematics (self , th , end_only : bool = True ):
483
+ """ Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """
457
484
if end_only :
458
485
frame_indices = self .get_frame_indices (self ._serial_frames [- 1 ].name )
459
486
else :
0 commit comments