1
1
from functools import lru_cache
2
- from pytorch_kinematics .transforms .rotation_conversions import axis_and_angle_to_matrix
3
- from pytorch_kinematics .transforms .rotation_conversions import axis_and_d_to_pris_matrix
4
2
from typing import Optional , Sequence
5
3
6
4
import numpy as np
9
7
import pytorch_kinematics .transforms as tf
10
8
from pytorch_kinematics import jacobian
11
9
from pytorch_kinematics .frame import Frame , Link , Joint
10
+ from pytorch_kinematics .transforms .rotation_conversions import axis_and_angle_to_matrix_44 , axis_and_d_to_pris_matrix
12
11
13
12
14
- def get_th_size (th ):
13
+ def get_n_joints (th ):
15
14
"""
16
15
17
16
Args:
@@ -28,6 +27,19 @@ def get_th_size(th):
28
27
raise NotImplementedError (f"Unsupported type { type (th )} " )
29
28
30
29
30
+ def get_batch_size (th ):
31
+ if isinstance (th , torch .Tensor ) or isinstance (th , np .ndarray ):
32
+ return th .shape [0 ]
33
+ elif isinstance (th , dict ):
34
+ elem_shape = get_dict_elem_shape (th )
35
+ return elem_shape [0 ]
36
+ elif isinstance (th , list ):
37
+ # Lists cannot be batched. We don't allow lists of lists.
38
+ return 1
39
+ else :
40
+ raise NotImplementedError (f"Unsupported type { type (th )} " )
41
+
42
+
31
43
def ensure_2d_tensor (th , dtype , device ):
32
44
if not torch .is_tensor (th ):
33
45
th = torch .tensor (th , dtype = dtype , device = device )
@@ -282,6 +294,18 @@ def get_frame_indices(self, *frame_names):
282
294
return torch .tensor ([self .frame_to_idx [n ] for n in frame_names ], dtype = torch .long , device = self .device )
283
295
284
296
def forward_kinematics (self , th , frame_indices : Optional = None ):
297
+ """
298
+ Compute forward kinematics for the given joint values.
299
+
300
+ Args:
301
+ th: A dict, list, numpy array, or torch tensor of joints values. Possibly batched.
302
+ frame_indices: A list of frame indices to compute transforms for. If None, all frames are computed.
303
+ Use `get_frame_indices` to convert from frame names to frame indices.
304
+
305
+ Returns:
306
+ A dict of Transform3d objects for each frame.
307
+
308
+ """
285
309
if frame_indices is None :
286
310
frame_indices = self .get_all_frame_indices ()
287
311
@@ -294,7 +318,7 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
294
318
# compute all joint transforms at once first
295
319
# in order to handle multiple joint types without branching, we create all possible transforms
296
320
# for all joint types and then select the appropriate one for each joint.
297
- rev_jnt_transform = axis_and_angle_to_matrix (axes_expanded , th )
321
+ rev_jnt_transform = axis_and_angle_to_matrix_44 (axes_expanded , th )
298
322
pris_jnt_transform = axis_and_d_to_pris_matrix (axes_expanded , th )
299
323
300
324
frame_transforms = {}
@@ -336,7 +360,7 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
336
360
def ensure_tensor (self , th ):
337
361
"""
338
362
Converts a number of possible types into a tensor. The order of the tensor is determined by the order
339
- of self.get_joint_parameter_names().
363
+ of self.get_joint_parameter_names(). th must contain all joints in the entire chain.
340
364
"""
341
365
if isinstance (th , np .ndarray ):
342
366
th = torch .tensor (th , device = self .device , dtype = self .dtype )
@@ -362,32 +386,17 @@ def get_all_frame_indices(self):
362
386
frame_indices = self .get_frame_indices (* self .get_frame_names (exclude_fixed = False ))
363
387
return frame_indices
364
388
365
- def ensure_dict_of_2d_tensors (self , th ):
366
- if not isinstance (th , dict ):
367
- th , _ = ensure_2d_tensor (th , self .dtype , self .device )
368
- assert self .n_joints == th .shape [- 1 ]
369
- th_dict = dict ((j , th [..., i ]) for i , j in enumerate (self .get_joint_parameter_names ()))
370
- else :
371
- th_dict = {k : ensure_2d_tensor (v , self .dtype , self .device )[0 ] for k , v in th .items ()}
372
- return th_dict
373
-
374
389
def clamp (self , th ):
375
- th_dict = self . ensure_dict_of_2d_tensors ( th )
390
+ """
376
391
377
- out_th_dict = {}
378
- for joint_name , joint_position in th_dict .items ():
379
- joint = self .find_joint (joint_name )
380
- joint_position_clamped = joint .clamp (joint_position )
381
- out_th_dict [joint_name ] = joint_position_clamped
392
+ Args:
393
+ th: Joint configuration
382
394
383
- return self .match_input_type ( out_th_dict , th )
395
+ Returns: Always a tensor in the order of self.get_joint_parameter_names(), possibly batched.
384
396
385
- @staticmethod
386
- def match_input_type (th_dict , th ):
387
- if isinstance (th , dict ):
388
- return th_dict
389
- else :
390
- return torch .stack ([v for v in th_dict .values ()], dim = - 1 )
397
+ """
398
+ th = self .ensure_tensor (th )
399
+ return torch .clamp (th , self .low , self .high )
391
400
392
401
def get_joint_limits (self ):
393
402
low = []
@@ -466,22 +475,33 @@ def forward_kinematics(self, th, end_only: bool = True):
466
475
return mat
467
476
468
477
def convert_serial_inputs_to_chain_inputs (self , th , end_only : bool ):
478
+ # th = self.ensure_tensor(th)
479
+ th_b = get_batch_size (th )
480
+ th_n_joints = get_n_joints (th )
481
+ if isinstance (th , list ):
482
+ th = torch .tensor (th , device = self .device , dtype = self .dtype )
483
+
469
484
if end_only :
470
485
frame_indices = self .get_frame_indices (self ._serial_frames [- 1 ].name )
471
486
else :
472
487
# pass through default behavior for frame indices being None, which is currently
473
488
# to return all frames.
474
489
frame_indices = None
475
- if get_th_size ( th ) < self .n_joints :
490
+ if th_n_joints < self .n_joints :
476
491
# if th is only a partial list of joints, assume it's a list of joints for only the serial chain.
477
492
partial_th = th
478
493
nonfixed_serial_frames = list (filter (lambda f : f .joint .joint_type != 'fixed' , self ._serial_frames ))
479
- if len (nonfixed_serial_frames ) != len (partial_th ):
480
- raise ValueError (f'Expected { len (nonfixed_serial_frames )} joint values, got { len (partial_th )} .' )
481
- th = torch .zeros (self .n_joints , device = self .device , dtype = self .dtype )
482
- for frame , partial_th_i in zip (nonfixed_serial_frames , partial_th ):
494
+ if th_n_joints != len (nonfixed_serial_frames ):
495
+ raise ValueError (f'Expected { len (nonfixed_serial_frames )} joint values, got { th_n_joints } .' )
496
+ th = torch .zeros ([th_b , self .n_joints ], device = self .device , dtype = self .dtype )
497
+ for i , frame in enumerate (nonfixed_serial_frames ):
498
+ joint_name = frame .joint .name
499
+ if isinstance (partial_th , dict ):
500
+ partial_th_i = partial_th [joint_name ]
501
+ else :
502
+ partial_th_i = partial_th [..., i ]
483
503
k = self .frame_to_idx [frame .name ]
484
504
jnt_idx = self .joint_indices [k ]
485
505
if frame .joint .joint_type != 'fixed' :
486
- th [jnt_idx ] = partial_th_i
506
+ th [..., jnt_idx ] = partial_th_i
487
507
return frame_indices , th
0 commit comments