@@ -123,6 +123,7 @@ def __init__(self, root_frame, dtype=torch.float32, device="cpu"):
123
123
idx += 1
124
124
self .joint_type_indices = torch .tensor (self .joint_type_indices )
125
125
self .joint_indices = torch .tensor (self .joint_indices )
126
+ # We need to use a dict because torch.compile doesn't list lists of tensors
126
127
self .parents_indices = [torch .tensor (p , dtype = torch .long , device = self .device ) for p in self .parents_indices ]
127
128
128
129
def to (self , dtype = None , device = None ):
@@ -317,6 +318,58 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
317
318
318
319
return frame_names_and_transform3ds
319
320
321
+ def forward_kinematics_py (self , th , frame_indices : Optional = None ):
322
+ if frame_indices is None :
323
+ frame_indices = self .get_all_frame_indices ()
324
+
325
+ th = self .ensure_tensor (th )
326
+ th = torch .atleast_2d (th )
327
+
328
+ b = th .shape [0 ]
329
+ axes_expanded = self .axes .unsqueeze (0 ).repeat (b , 1 , 1 )
330
+
331
+ # compute all joint transforms at once first
332
+ # in order to handle multiple joint types without branching, we create all possible transforms
333
+ # for all joint types and then select the appropriate one for each joint.
334
+ rev_jnt_transform = tensor_axis_and_angle_to_matrix (axes_expanded , th )
335
+ pris_jnt_transform = tensor_axis_and_d_to_pris_matrix (axes_expanded , th )
336
+
337
+ frame_transforms = {}
338
+ b = th .shape [0 ]
339
+ for frame_idx in frame_indices :
340
+ frame_transform = torch .eye (4 ).to (th ).unsqueeze (0 ).repeat (b , 1 , 1 )
341
+
342
+ # iterate down the list and compose the transform
343
+ for chain_idx in self .parents_indices [frame_idx .item ()]:
344
+ if chain_idx .item () in frame_transforms :
345
+ frame_transform = frame_transforms [chain_idx .item ()]
346
+ else :
347
+ link_offset_i = self .link_offsets [chain_idx ]
348
+ if link_offset_i is not None :
349
+ frame_transform = frame_transform @ link_offset_i
350
+
351
+ joint_offset_i = self .joint_offsets [chain_idx ]
352
+ if joint_offset_i is not None :
353
+ frame_transform = frame_transform @ joint_offset_i
354
+
355
+ jnt_idx = self .joint_indices [chain_idx ]
356
+ jnt_type = self .joint_type_indices [chain_idx ]
357
+ if jnt_type == 0 :
358
+ pass
359
+ elif jnt_type == 1 :
360
+ jnt_transform_i = rev_jnt_transform [:, jnt_idx ]
361
+ frame_transform = frame_transform @ jnt_transform_i
362
+ elif jnt_type == 2 :
363
+ jnt_transform_i = pris_jnt_transform [:, jnt_idx ]
364
+ frame_transform = frame_transform @ jnt_transform_i
365
+
366
+ frame_transforms [frame_idx .item ()] = frame_transform
367
+
368
+ frame_names_and_transform3ds = {self .idx_to_frame [frame_idx ]: tf .Transform3d (matrix = transform ) for
369
+ frame_idx , transform in frame_transforms .items ()}
370
+
371
+ return frame_names_and_transform3ds
372
+
320
373
def ensure_tensor (self , th ):
321
374
"""
322
375
Converts a number of possible types into a tensor. The order of the tensor is determined by the order
0 commit comments