Skip to content

Commit 81f2caf

Browse files
author
Peter
committed
address review comments
1 parent 1951e74 commit 81f2caf

File tree

4 files changed

+22
-39
lines changed

4 files changed

+22
-39
lines changed

src/pytorch_kinematics/transforms/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
axis_and_angle_to_matrix,
55
axis_angle_to_quaternion,
66
euler_angles_to_matrix,
7+
matrix_to_axis_angle,
78
matrix_to_euler_angles,
89
matrix_to_quaternion,
910
matrix_to_rotation_6d,
10-
matrix_to_axis_angle,
1111
quaternion_apply,
1212
quaternion_invert,
1313
quaternion_multiply,
@@ -18,8 +18,10 @@
1818
random_rotations,
1919
rotation_6d_to_matrix,
2020
standardize_quaternion,
21+
tensor_axis_and_angle_to_matrix,
22+
tensor_axis_and_d_to_pris_matrix,
23+
wxyz_to_xyzw,
2124
xyzw_to_wxyz,
22-
wxyz_to_xyzw
2325
)
2426
from .so3 import (
2527
so3_exp_map,

src/pytorch_kinematics/transforms/rotation_conversions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ def quaternion_apply(quaternion, point):
452452

453453
def tensor_axis_and_d_to_pris_matrix(axis, d):
454454
"""
455+
Creates a 4x4 matrix that represents a translation along an axis of a distance d
455456
Works with any number of batch dimensions.
456457
457458
Args:
@@ -471,6 +472,7 @@ def tensor_axis_and_d_to_pris_matrix(axis, d):
471472

472473
def tensor_axis_and_angle_to_matrix(axis, theta):
473474
"""
475+
Creates a 4x4 matrix that represents a rotation around an axis by an angle theta.
474476
Works with any number of batch dimensions.
475477
476478
Args:

tests/gen_fk_perf.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@ def main():
1111
np.set_printoptions(precision=3, suppress=True, linewidth=220)
1212
torch.set_printoptions(precision=3, sci_mode=False, linewidth=220)
1313

14-
chains = [
15-
pk.build_chain_from_mjcf(open('val.xml').read()),
16-
pk.build_serial_chain_from_mjcf(open('val.xml').read(), end_link_name='left_tool'),
17-
pk.build_serial_chain_from_urdf(open('kuka_iiwa.urdf').read(), end_link_name='lbr_iiwa_link_7'),
18-
]
19-
names = ['val', 'val_serial', 'kuka_iiwa']
14+
chains = {
15+
'val': pk.build_chain_from_mjcf(open('val.xml').read()),
16+
'val_serial': pk.build_serial_chain_from_mjcf(open('val.xml').read(), end_link_name='left_tool'),
17+
'kuka_iiwa': pk.build_serial_chain_from_urdf(open('kuka_iiwa.urdf').read(), end_link_name='lbr_iiwa_link_7'),
18+
}
2019

2120
devices = ['cpu', 'cuda']
2221
dtypes = [torch.float32, torch.float64]
@@ -27,27 +26,19 @@ def main():
2726
headers = ['method', 'chain', 'device', 'dtype', 'batch_size', 'time']
2827
data = []
2928

30-
def _fk_cpp(th):
29+
def _fk(th):
3130
return chain.forward_kinematics(th)
3231

33-
@torch.compile(backend='eager')
34-
def _fk_torch_compile(th):
35-
return chain.forward_kinematics_py(th)
36-
37-
method_names = ['fk_cpp', 'fk_torch_compile']
38-
methods = [_fk_cpp, _fk_torch_compile]
39-
40-
for chain, name in zip(chains, names):
32+
for name, chain in chains.items():
4133
for device in devices:
4234
for dtype in dtypes:
4335
for batch_size in batch_sizes:
44-
for method_name, method in zip(method_names, methods):
45-
chain = chain.to(dtype=dtype, device=device)
46-
th = torch.zeros(batch_size, chain.n_joints).to(dtype=dtype, device=device)
36+
chain = chain.to(dtype=dtype, device=device)
37+
th = torch.zeros(batch_size, chain.n_joints).to(dtype=dtype, device=device)
4738

48-
dt = timeit.timeit(lambda: method(th), number=number)
49-
data.append([name, device, dtype, batch_size, dt / number])
50-
print(f"{method_name} {name=} {device=} {dtype=} {batch_size=} {dt / number:.4f}")
39+
dt = timeit.timeit(lambda: _fk(th), number=number)
40+
data.append([name, device, dtype, batch_size, dt / number])
41+
print(f"{name=} {device=} {dtype=} {batch_size=} {dt / number:.4f}")
5142

5243
# pickle the data for visualization in jupyter notebook
5344
import pickle

tests/test_kinematics.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,9 @@ def test_fk_mjcf():
2727
chain = chain.to(dtype=torch.float64)
2828
print(chain)
2929
print(chain.get_joint_parameter_names())
30-
th = {
31-
'hip_1': 1.0,
32-
'ankle_1': 1,
33-
'hip_2': 0.0,
34-
'ankle_2': 0.0,
35-
'hip_3': 0.0,
36-
'ankle_3': 0.0,
37-
'hip_4': 0.0,
38-
'ankle_4': 0.0,
39-
}
30+
31+
th = {joint: 0.0 for joint in chain.get_joint_parameter_names()}
32+
th.update({'hip_1': 1.0, 'ankle_1': 1})
4033
ret = chain.forward_kinematics(th)
4134
tg = ret['aux_1']
4235
pos, rot = quat_pos_from_transform3d(tg)
@@ -173,12 +166,7 @@ def test_fk_simple_arm():
173166
assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=torch.float64))
174167

175168
N = 100
176-
ret = chain.forward_kinematics({
177-
'arm_shoulder_pan_joint': torch.rand(N),
178-
'arm_elbow_pan_joint': torch.rand(N),
179-
'arm_wrist_lift_joint': torch.rand(N),
180-
'arm_wrist_roll_joint': torch.rand(N),
181-
})
169+
ret = chain.forward_kinematics({k: torch.rand(N) for k in chain.get_joint_parameter_names()})
182170
tg = ret['arm_wrist_roll']
183171
assert list(tg.get_matrix().shape) == [N, 4, 4]
184172

0 commit comments

Comments
 (0)