Skip to content

Commit 61b74ea

Browse files
author
Peter
committed
jupyter notebook to explore performance data
1 parent 57b74fe commit 61b74ea

File tree

4 files changed

+266
-178
lines changed

4 files changed

+266
-178
lines changed

src/pytorch_kinematics/chain.py

Lines changed: 0 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -317,59 +317,6 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
317317

318318
return frame_names_and_transform3ds
319319

320-
def forward_kinematics_py(self, th, frame_indices: Optional = None):
321-
if frame_indices is None:
322-
frame_indices = self.get_all_frame_indices()
323-
324-
th = self.ensure_tensor(th)
325-
th = torch.atleast_2d(th)
326-
327-
b = th.shape[0]
328-
329-
axes_expanded = self.axes.unsqueeze(0).repeat(b, 1, 1)
330-
331-
frame_transforms = {}
332-
333-
# compute all joint transforms at once first
334-
# in order to handle multiple joint types without branching, we create all possible transforms
335-
# for all joint types and then select the appropriate one for each joint.
336-
rev_jnt_transform = tensor_axis_and_angle_to_matrix(axes_expanded, th)
337-
pris_jnt_transform = tensor_axis_and_d_to_pris_matrix(axes_expanded, th)
338-
339-
for frame_idx in frame_indices:
340-
frame_transform = torch.eye(4).to(th).unsqueeze(0).repeat(b, 1, 1)
341-
342-
# iterate down the list and compose the transform
343-
for chain_idx in self.parents_indices[frame_idx]:
344-
if chain_idx.item() in frame_transforms:
345-
frame_transform = frame_transforms[chain_idx.item()]
346-
else:
347-
link_offset_i = self.link_offsets[chain_idx]
348-
if link_offset_i is not None:
349-
frame_transform = frame_transform @ link_offset_i
350-
351-
joint_offset_i = self.joint_offsets[chain_idx]
352-
if joint_offset_i is not None:
353-
frame_transform = frame_transform @ joint_offset_i
354-
355-
jnt_idx = self.joint_indices[chain_idx]
356-
jnt_type = self.joint_type_indices[chain_idx]
357-
if jnt_type == 0:
358-
pass
359-
elif jnt_type == 1:
360-
jnt_transform_i = rev_jnt_transform[:, jnt_idx]
361-
frame_transform = frame_transform @ jnt_transform_i
362-
elif jnt_type == 2:
363-
jnt_transform_i = pris_jnt_transform[:, jnt_idx]
364-
frame_transform = frame_transform @ jnt_transform_i
365-
366-
frame_transforms[frame_idx.item()] = frame_transform
367-
368-
frame_names_and_transform3ds = {self.idx_to_frame[frame_idx]: tf.Transform3d(matrix=transform) for
369-
frame_idx, transform in frame_transforms.items()}
370-
371-
return frame_names_and_transform3ds
372-
373320
def ensure_tensor(self, th):
374321
"""
375322
Converts a number of possible types into a tensor. The order of the tensor is determined by the order
@@ -502,18 +449,6 @@ def forward_kinematics(self, th, end_only: bool = True):
502449
else:
503450
return mat
504451

505-
def forward_kinematics_py(self, th, end_only: bool = True):
506-
""" Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """
507-
frame_indices, th = self.convert_serial_inputs_to_chain_inputs(end_only, th)
508-
509-
mat = super().forward_kinematics_py(th, frame_indices)
510-
511-
if end_only:
512-
return mat[self._serial_frames[-1].name]
513-
else:
514-
return mat
515-
516-
517452
def convert_serial_inputs_to_chain_inputs(self, end_only, th):
518453
if end_only:
519454
frame_indices = self.get_frame_indices(self._serial_frames[-1].name)
@@ -534,29 +469,3 @@ def convert_serial_inputs_to_chain_inputs(self, end_only, th):
534469
if frame.joint.joint_type != 'fixed':
535470
th[jnt_idx] = partial_th_i
536471
return frame_indices, th
537-
538-
def forward_kinematics_slow(self, th, world=None, end_only=True):
539-
if world is None:
540-
world = tf.Transform3d()
541-
if world.dtype != self.dtype or world.device != self.device:
542-
world = world.to(dtype=self.dtype, device=self.device, copy=True)
543-
th, N = ensure_2d_tensor(th, self.dtype, self.device)
544-
zeros = torch.zeros([N, 1], dtype=world.dtype, device=world.device)
545-
546-
theta_idx = 0
547-
link_transforms = {}
548-
trans = tf.Transform3d(matrix=world.get_matrix().repeat(N, 1, 1))
549-
for f in self._serial_frames:
550-
if f.link.offset is not None:
551-
trans = trans.compose(f.link.offset)
552-
553-
if f.joint.joint_type == "fixed": # If fixed
554-
trans = trans.compose(f.get_transform(zeros))
555-
else:
556-
joint_transform = f.get_transform(th[:, theta_idx].view(N, 1))
557-
trans = trans.compose(joint_transform)
558-
theta_idx += 1
559-
560-
link_transforms[f.link.name] = trans
561-
562-
return link_transforms[self._serial_frames[-1].link.name] if end_only else link_transforms
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {
7+
"ExecuteTime": {
8+
"end_time": "2023-10-30T15:37:54.975496787Z",
9+
"start_time": "2023-10-30T15:37:54.933122254Z"
10+
}
11+
},
12+
"outputs": [
13+
{
14+
"ename": "ModuleNotFoundError",
15+
"evalue": "No module named 'torch'",
16+
"output_type": "error",
17+
"traceback": [
18+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
19+
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
20+
"\u001b[0;32m<ipython-input-1-5e79930a6286>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Load the data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'fk_perf.pkl'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
21+
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch'"
22+
]
23+
}
24+
],
25+
"source": [
26+
"# Load the data\n",
27+
"import torch\n",
28+
"import pickle\n",
29+
"\n",
30+
"with open('fk_perf.pkl', 'rb') as f:\n",
31+
" headers, data = pickle.load(f)\n",
32+
"\n",
33+
"import pandas as pd\n",
34+
"df = pd.DataFrame(data, columns=headers)\n",
35+
"\n",
36+
"df"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": 2,
42+
"metadata": {},
43+
"outputs": [
44+
{
45+
"data": {
46+
"text/plain": [
47+
"'/usr/bin/python3'"
48+
]
49+
},
50+
"execution_count": 2,
51+
"metadata": {},
52+
"output_type": "execute_result"
53+
}
54+
],
55+
"source": [
56+
"import sys\n",
57+
"sys.executable\n"
58+
]
59+
},
60+
{
61+
"cell_type": "code",
62+
"execution_count": null,
63+
"metadata": {},
64+
"outputs": [],
65+
"source": []
66+
}
67+
],
68+
"metadata": {
69+
"kernelspec": {
70+
"display_name": "Python 3",
71+
"language": "python",
72+
"name": "python3"
73+
},
74+
"language_info": {
75+
"codemirror_mode": {
76+
"name": "ipython",
77+
"version": 3
78+
},
79+
"file_extension": ".py",
80+
"mimetype": "text/x-python",
81+
"name": "python",
82+
"nbconvert_exporter": "python",
83+
"pygments_lexer": "ipython3",
84+
"version": "3.8.10"
85+
}
86+
},
87+
"nbformat": 4,
88+
"nbformat_minor": 5
89+
}

tests/test_fk_perf.py

Lines changed: 1 addition & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -5,93 +5,6 @@
55
import pytorch_kinematics as pk
66
import numpy as np
77

8-
N = 10000
9-
number = 100
10-
11-
12-
def test_val_fk_correctness():
13-
val = pk.build_chain_from_mjcf(open('val.xml').read())
14-
val = val.to(dtype=torch.float32, device='cuda')
15-
16-
th = torch.zeros(N, 20, dtype=torch.float32, device='cuda')
17-
18-
frame_indices = val.get_frame_indices('left_tool', 'right_tool')
19-
t_py = val.forward_kinematics_py(th, frame_indices)
20-
t_cpp = val.forward_kinematics(th, frame_indices)
21-
l_py = t_py['left_tool'].get_matrix()
22-
l_cpp = t_cpp['left_tool'].get_matrix()
23-
r_py = t_py['right_tool'].get_matrix()
24-
r_cpp = t_cpp['right_tool'].get_matrix()
25-
26-
assert torch.allclose(l_py, l_cpp)
27-
assert torch.allclose(r_py, r_cpp)
28-
29-
30-
def test_val_fk_perf():
31-
val = pk.build_serial_chain_from_mjcf(open('val.xml').read(), end_link_name='left_tool')
32-
val = val.to(dtype=torch.float32, device='cuda')
33-
34-
th = torch.zeros(N, 20, dtype=torch.float32, device='cuda')
35-
36-
def _val_old_fk():
37-
tg = val.forward_kinematics_slow(th, end_only=True)
38-
m = tg.get_matrix()
39-
return m
40-
41-
def _val_new_py_fk():
42-
tg = val.forward_kinematics_py(th, end_only=True)
43-
m = tg.get_matrix()
44-
return m
45-
46-
def _val_new_cpp_fk():
47-
tg = val.forward_kinematics(th, end_only=True)
48-
m = tg.get_matrix()
49-
return m
50-
51-
val_old_dt = timeit.timeit(_val_old_fk, number=number)
52-
print(f'Val FK OLD dt: {val_old_dt / number:.4f}')
53-
54-
val_new_py_dt = timeit.timeit(_val_new_py_fk, number=number)
55-
print(f'Val FK NEW dt: {val_new_py_dt / number:.4f}')
56-
57-
val_new_cpp_dt = timeit.timeit(_val_new_cpp_fk, number=number)
58-
print(f'Val FK NEW C++ dt: {val_new_cpp_dt / number:.4f}')
59-
60-
assert val_old_dt > val_new_cpp_dt
61-
62-
63-
def test_kuka_fk_perf():
64-
kuka = pk.build_serial_chain_from_urdf(open('kuka_iiwa.urdf').read(), end_link_name='lbr_iiwa_link_7')
65-
kuka = kuka.to(dtype=torch.float32, device='cuda')
66-
67-
th = torch.zeros(N, 7, dtype=torch.float32, device='cuda')
68-
69-
def _kuka_old_fk():
70-
tg = kuka.forward_kinematics_slow(th, end_only=True)
71-
m = tg.get_matrix()
72-
return m
73-
74-
def _kuka_new_py_fk():
75-
tg = kuka.forward_kinematics_py(th, end_only=True)
76-
m = tg.get_matrix()
77-
return m
78-
79-
def _kuka_new_cpp_fk():
80-
tg = kuka.forward_kinematics(th, end_only=True)
81-
m = tg.get_matrix()
82-
return m
83-
84-
kuka_old_dt = timeit.timeit(_kuka_old_fk, number=number)
85-
print(f'Kuka FK OLD dt: {kuka_old_dt / number:.4f}')
86-
87-
kuka_new_py_dt = timeit.timeit(_kuka_new_py_fk, number=number)
88-
print(f'Kuka FK NEW dt: {kuka_new_py_dt / number:.4f}')
89-
90-
kuka_new_cpp_dt = timeit.timeit(_kuka_new_cpp_fk, number=number)
91-
print(f'Kuka FK NEW C++ dt: {kuka_new_cpp_dt / number:.4f}')
92-
93-
assert kuka_old_dt > kuka_new_cpp_dt
94-
958

969
def main():
9710
# do an in-depth analysis of multiple models, devices, data types, batch sizes, etc.
@@ -108,6 +21,7 @@ def main():
10821
devices = ['cpu', 'cuda']
10922
dtypes = [torch.float32, torch.float64]
11023
batch_sizes = [1, 10, 100, 1_000, 10_000, 100_000]
24+
number = 100
11125

11226
# iterate over all combinations and store in a pandas dataframe
11327
headers = ['chain', 'device', 'dtype', 'batch_size', 'time']

tests/viz_fk_perf.ipynb

Lines changed: 176 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)