Skip to content

Commit 25cd7c4

Browse files
author
Peter
committed
fix handling of batched inputs for serial chains and add new test
1 parent bc9c65d commit 25cd7c4

File tree

8 files changed

+110
-46
lines changed

8 files changed

+110
-46
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ See `tests` for code samples; some are also shown here.
2222
If you use this package in your research, consider citing
2323
```
2424
@software{Zhong_PyTorch_Kinematics_2023,
25-
author = {Zhong, Sheng and Power, Thomas and Gupta, Ashwin},
25+
author = {Zhong, Sheng and Power, Thomas and Gupta, Ashwin and Peter, Mitrano},
2626
doi = {10.5281/zenodo.7700588},
2727
month = {3},
2828
title = {{PyTorch Kinematics}},

src/pytorch_kinematics/chain.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from functools import lru_cache
2-
from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix
3-
from pytorch_kinematics.transforms.rotation_conversions import axis_and_d_to_pris_matrix
42
from typing import Optional, Sequence
53

64
import numpy as np
@@ -9,9 +7,10 @@
97
import pytorch_kinematics.transforms as tf
108
from pytorch_kinematics import jacobian
119
from pytorch_kinematics.frame import Frame, Link, Joint
10+
from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix_44, axis_and_d_to_pris_matrix
1211

1312

14-
def get_th_size(th):
13+
def get_n_joints(th):
1514
"""
1615
1716
Args:
@@ -28,6 +27,19 @@ def get_th_size(th):
2827
raise NotImplementedError(f"Unsupported type {type(th)}")
2928

3029

30+
def get_batch_size(th):
31+
if isinstance(th, torch.Tensor) or isinstance(th, np.ndarray):
32+
return th.shape[0]
33+
elif isinstance(th, dict):
34+
elem_shape = get_dict_elem_shape(th)
35+
return elem_shape[0]
36+
elif isinstance(th, list):
37+
# Lists cannot be batched. We don't allow lists of lists.
38+
return 1
39+
else:
40+
raise NotImplementedError(f"Unsupported type {type(th)}")
41+
42+
3143
def ensure_2d_tensor(th, dtype, device):
3244
if not torch.is_tensor(th):
3345
th = torch.tensor(th, dtype=dtype, device=device)
@@ -282,6 +294,18 @@ def get_frame_indices(self, *frame_names):
282294
return torch.tensor([self.frame_to_idx[n] for n in frame_names], dtype=torch.long, device=self.device)
283295

284296
def forward_kinematics(self, th, frame_indices: Optional = None):
297+
"""
298+
Compute forward kinematics for the given joint values.
299+
300+
Args:
301+
th: A dict, list, numpy array, or torch tensor of joints values. Possibly batched.
302+
frame_indices: A list of frame indices to compute transforms for. If None, all frames are computed.
303+
Use `get_frame_indices` to convert from frame names to frame indices.
304+
305+
Returns:
306+
A dict of Transform3d objects for each frame.
307+
308+
"""
285309
if frame_indices is None:
286310
frame_indices = self.get_all_frame_indices()
287311

@@ -294,7 +318,7 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
294318
# compute all joint transforms at once first
295319
# in order to handle multiple joint types without branching, we create all possible transforms
296320
# for all joint types and then select the appropriate one for each joint.
297-
rev_jnt_transform = axis_and_angle_to_matrix(axes_expanded, th)
321+
rev_jnt_transform = axis_and_angle_to_matrix_44(axes_expanded, th)
298322
pris_jnt_transform = axis_and_d_to_pris_matrix(axes_expanded, th)
299323

300324
frame_transforms = {}
@@ -336,7 +360,7 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
336360
def ensure_tensor(self, th):
337361
"""
338362
Converts a number of possible types into a tensor. The order of the tensor is determined by the order
339-
of self.get_joint_parameter_names().
363+
of self.get_joint_parameter_names(). th must contain all joints in the entire chain.
340364
"""
341365
if isinstance(th, np.ndarray):
342366
th = torch.tensor(th, device=self.device, dtype=self.dtype)
@@ -362,32 +386,17 @@ def get_all_frame_indices(self):
362386
frame_indices = self.get_frame_indices(*self.get_frame_names(exclude_fixed=False))
363387
return frame_indices
364388

365-
def ensure_dict_of_2d_tensors(self, th):
366-
if not isinstance(th, dict):
367-
th, _ = ensure_2d_tensor(th, self.dtype, self.device)
368-
assert self.n_joints == th.shape[-1]
369-
th_dict = dict((j, th[..., i]) for i, j in enumerate(self.get_joint_parameter_names()))
370-
else:
371-
th_dict = {k: ensure_2d_tensor(v, self.dtype, self.device)[0] for k, v in th.items()}
372-
return th_dict
373-
374389
def clamp(self, th):
375-
th_dict = self.ensure_dict_of_2d_tensors(th)
390+
"""
376391
377-
out_th_dict = {}
378-
for joint_name, joint_position in th_dict.items():
379-
joint = self.find_joint(joint_name)
380-
joint_position_clamped = joint.clamp(joint_position)
381-
out_th_dict[joint_name] = joint_position_clamped
392+
Args:
393+
th: Joint configuration
382394
383-
return self.match_input_type(out_th_dict, th)
395+
Returns: Always a tensor in the order of self.get_joint_parameter_names(), possibly batched.
384396
385-
@staticmethod
386-
def match_input_type(th_dict, th):
387-
if isinstance(th, dict):
388-
return th_dict
389-
else:
390-
return torch.stack([v for v in th_dict.values()], dim=-1)
397+
"""
398+
th = self.ensure_tensor(th)
399+
return torch.clamp(th, self.low, self.high)
391400

392401
def get_joint_limits(self):
393402
low = []
@@ -466,22 +475,33 @@ def forward_kinematics(self, th, end_only: bool = True):
466475
return mat
467476

468477
def convert_serial_inputs_to_chain_inputs(self, th, end_only: bool):
478+
# th = self.ensure_tensor(th)
479+
th_b = get_batch_size(th)
480+
th_n_joints = get_n_joints(th)
481+
if isinstance(th, list):
482+
th = torch.tensor(th, device=self.device, dtype=self.dtype)
483+
469484
if end_only:
470485
frame_indices = self.get_frame_indices(self._serial_frames[-1].name)
471486
else:
472487
# pass through default behavior for frame indices being None, which is currently
473488
# to return all frames.
474489
frame_indices = None
475-
if get_th_size(th) < self.n_joints:
490+
if th_n_joints < self.n_joints:
476491
# if th is only a partial list of joints, assume it's a list of joints for only the serial chain.
477492
partial_th = th
478493
nonfixed_serial_frames = list(filter(lambda f: f.joint.joint_type != 'fixed', self._serial_frames))
479-
if len(nonfixed_serial_frames) != len(partial_th):
480-
raise ValueError(f'Expected {len(nonfixed_serial_frames)} joint values, got {len(partial_th)}.')
481-
th = torch.zeros(self.n_joints, device=self.device, dtype=self.dtype)
482-
for frame, partial_th_i in zip(nonfixed_serial_frames, partial_th):
494+
if th_n_joints != len(nonfixed_serial_frames):
495+
raise ValueError(f'Expected {len(nonfixed_serial_frames)} joint values, got {th_n_joints}.')
496+
th = torch.zeros([th_b, self.n_joints], device=self.device, dtype=self.dtype)
497+
for i, frame in enumerate(nonfixed_serial_frames):
498+
joint_name = frame.joint.name
499+
if isinstance(partial_th, dict):
500+
partial_th_i = partial_th[joint_name]
501+
else:
502+
partial_th_i = partial_th[..., i]
483503
k = self.frame_to_idx[frame.name]
484504
jnt_idx = self.joint_indices[k]
485505
if frame.joint.joint_type != 'fixed':
486-
th[jnt_idx] = partial_th_i
506+
th[..., jnt_idx] = partial_th_i
487507
return frame_indices, th

src/pytorch_kinematics/frame.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
import pytorch_kinematics.transforms as tf
4-
from pytorch_kinematics.transforms import axis_and_angle_to_matrix
4+
from pytorch_kinematics.transforms import axis_and_angle_to_matrix_33
55

66

77
class Visual(object):
@@ -115,7 +115,7 @@ def get_transform(self, theta):
115115
dtype = self.joint.axis.dtype
116116
d = self.joint.axis.device
117117
if self.joint.joint_type == 'revolute':
118-
rot = axis_and_angle_to_matrix(self.joint.axis, theta)
118+
rot = axis_and_angle_to_matrix_33(self.joint.axis, theta)
119119
t = tf.Transform3d(rot=rot, dtype=dtype, device=d)
120120
elif self.joint.joint_type == 'prismatic':
121121
t = tf.Transform3d(pos=theta * self.joint.axis, dtype=dtype, device=d)

src/pytorch_kinematics/jacobian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def calc_jacobian(serial_chain, th, tool=None):
4747
elif f.joint.joint_type == "prismatic":
4848
cnt += 1
4949
j_eef[:, :3, -cnt] = f.joint.axis.repeat(N, 1) @ cur_transform[:, :3, :3]
50-
cur_frame_transform = f.get_transform(th[:, -cnt].reshape(N, 1)).get_matrix()
50+
cur_frame_transform = f.get_transform(th[:, -cnt]).get_matrix()
5151
cur_transform = cur_frame_transform @ cur_transform
5252

5353
# currently j_eef is Jacobian in end-effector frame, convert to base/world frame

src/pytorch_kinematics/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
random_rotations,
2020
rotation_6d_to_matrix,
2121
standardize_quaternion,
22-
axis_and_angle_to_matrix,
22+
axis_and_angle_to_matrix_33,
2323
axis_and_d_to_pris_matrix,
2424
wxyz_to_xyzw,
2525
xyzw_to_wxyz,

src/pytorch_kinematics/transforms/rotation_conversions.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def axis_and_d_to_pris_matrix(axis, d):
472472
return mat44
473473

474474

475-
def axis_and_angle_to_matrix(axis, theta):
475+
def axis_and_angle_to_matrix_44(axis, theta):
476476
"""
477477
Creates a 4x4 matrix that represents a rotation around an axis by an angle theta.
478478
Works with any number of batch dimensions.
@@ -483,6 +483,25 @@ def axis_and_angle_to_matrix(axis, theta):
483483
484484
Returns: [..., 4, 4]
485485
486+
"""
487+
rot = axis_and_angle_to_matrix_33(axis, theta)
488+
batch_shape = axis.shape[:-1]
489+
mat44 = torch.cat((rot, torch.zeros(*batch_shape, 3, 1).to(axis)), -1)
490+
mat44 = torch.cat((mat44, torch.tensor([0.0, 0.0, 0.0, 1.0]).expand(*batch_shape, 1, 4).to(axis)), -2)
491+
return mat44
492+
493+
494+
def axis_and_angle_to_matrix_33(axis, theta):
495+
"""
496+
Creates a 3x3 matrix that represents a rotation around an axis by an angle theta.
497+
Works with any number of batch dimensions.
498+
499+
Argsaxis.sh:
500+
axis: [..., 3]
501+
theta: [ ...]
502+
503+
Returns: [..., 3, 3]
504+
486505
"""
487506
# based on https://ai.stackexchange.com/questions/14041/, and checked against wikipedia
488507
c = torch.cos(theta) # NOTE: cos is not that precise for float32, you may want to use float64
@@ -501,10 +520,7 @@ def axis_and_angle_to_matrix(axis, theta):
501520
rot = torch.stack([torch.stack([r00, r01, r02], -1),
502521
torch.stack([r10, r11, r12], -1),
503522
torch.stack([r20, r21, r22], -1)], -2)
504-
batch_shape = axis.shape[:-1]
505-
mat44 = torch.cat((rot, torch.zeros(*batch_shape, 3, 1).to(axis)), -1)
506-
mat44 = torch.cat((mat44, torch.tensor([0.0, 0.0, 0.0, 1.0]).expand(*batch_shape, 1, 4).to(axis)), -2)
507-
return mat44
523+
return rot
508524

509525

510526
def axis_angle_to_matrix(axis_angle):

tests/test_kinematics.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
2-
from timeit import timeit
32
import os
3+
from timeit import timeit
44

55
import torch
66

@@ -260,7 +260,35 @@ def test_fk_val():
260260
assert torch.allclose(pos, torch.tensor([-0.225692, 0.259045, 0.262139], dtype=torch.float64))
261261

262262

263+
def test_fk_partial_batched_dict():
264+
# Test that you can pass in dict of batched joint configs for a subset of the joints
265+
chain = pk.build_serial_chain_from_mjcf(open(os.path.join(TEST_DIR, "val.xml")).read(), 'left_tool')
266+
th = {
267+
'joint56': torch.zeros([1000], dtype=torch.float64),
268+
'joint57': torch.zeros([1000], dtype=torch.float64),
269+
'joint41': torch.zeros([1000], dtype=torch.float64),
270+
'joint42': torch.zeros([1000], dtype=torch.float64),
271+
'joint43': torch.zeros([1000], dtype=torch.float64),
272+
'joint44': torch.zeros([1000], dtype=torch.float64),
273+
'joint45': torch.zeros([1000], dtype=torch.float64),
274+
'joint46': torch.zeros([1000], dtype=torch.float64),
275+
'joint47': torch.zeros([1000], dtype=torch.float64),
276+
}
277+
chain = chain.to(dtype=torch.float64)
278+
tg = chain.forward_kinematics(th)
279+
280+
281+
def test_fk_partial_batched():
282+
# Test that you can pass in dict of batched joint configs for a subset of the joints
283+
chain = pk.build_serial_chain_from_mjcf(open(os.path.join(TEST_DIR, "val.xml")).read(), 'left_tool')
284+
th = torch.zeros([1000, 9], dtype=torch.float64)
285+
chain = chain.to(dtype=torch.float64)
286+
tg = chain.forward_kinematics(th)
287+
288+
263289
if __name__ == "__main__":
290+
test_fk_partial_batched()
291+
test_fk_partial_batched_dict()
264292
test_fk_val()
265293
test_sdf_serial_chain()
266294
test_urdf_serial()

tests/test_rotation_conversions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix, axis_angle_to_matrix, \
5+
from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix_33, axis_angle_to_matrix, \
66
pos_rot_to_matrix, matrix_to_pos_rot, random_rotations
77

88

@@ -18,7 +18,7 @@ def test_axis_angle_to_matrix_perf():
1818
dt1 = timeit.timeit(lambda: axis_angle_to_matrix(axis_angle), number=number)
1919
print(f'Old method: {dt1:.5f}')
2020

21-
dt2 = timeit.timeit(lambda: axis_and_angle_to_matrix(axis=axis_1d, theta=theta), number=number)
21+
dt2 = timeit.timeit(lambda: axis_and_angle_to_matrix_33(axis=axis_1d, theta=theta), number=number)
2222
print(f'New method: {dt2:.5f}')
2323

2424

0 commit comments

Comments
 (0)