Skip to content

Commit 1246088

Browse files
author
Peter
committed
testing torch.compile
1 parent 4385a64 commit 1246088

File tree

4 files changed

+98
-9
lines changed

4 files changed

+98
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,5 @@ mujoco = ["mujoco"]
7171
[build-system]
7272
# torch and ninja are included here because they are needed to build the native code.
7373
# They will be installed as dependencies during the build, which can take a while the first time.
74-
requires = ["setuptools>=60.0.0", "wheel", "torch", "ninja"]
74+
requires = ["setuptools>=60.0.0", "wheel", "torch==2.1.0", "ninja"]
7575
build-backend= "setuptools.build_meta"

src/pytorch_kinematics/chain.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def __init__(self, root_frame, dtype=torch.float32, device="cpu"):
123123
idx += 1
124124
self.joint_type_indices = torch.tensor(self.joint_type_indices)
125125
self.joint_indices = torch.tensor(self.joint_indices)
126+
# We need to use a dict because torch.compile doesn't like lists of tensors (TODO: the line below still builds a list)
126127
self.parents_indices = [torch.tensor(p, dtype=torch.long, device=self.device) for p in self.parents_indices]
127128

128129
def to(self, dtype=None, device=None):
@@ -317,6 +318,58 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
317318

318319
return frame_names_and_transform3ds
319320

321+
def forward_kinematics_py(self, th, frame_indices: Optional = None):
    """
    Compute forward kinematics in pure Python/PyTorch for the requested frames.

    Every joint transform is computed up front for the whole batch, then each
    requested frame is solved by walking its entry in ``self.parents_indices``
    and composing link offsets, joint offsets and joint transforms. Transforms
    of frames that were already solved in this call are reused.

    Args:
        th: joint values. Passed through ``self.ensure_tensor`` and promoted to
            at least 2-D; the leading dimension is used as the batch size.
        frame_indices: iterable of frame indices to solve for. Each entry may
            be a single-element integer tensor or a plain ``int``. Defaults to
            ``self.get_all_frame_indices()``.

    Returns:
        dict mapping frame name (looked up in ``self.idx_to_frame``) to a
        ``tf.Transform3d`` wrapping the composed batch of 4x4 matrices.
    """
    if frame_indices is None:
        frame_indices = self.get_all_frame_indices()

    th = self.ensure_tensor(th)
    th = torch.atleast_2d(th)

    b = th.shape[0]
    axes_expanded = self.axes.unsqueeze(0).repeat(b, 1, 1)

    # Compute all joint transforms at once first. In order to handle multiple
    # joint types without branching on the data, we create every possible
    # transform for every joint and select the appropriate one per joint below.
    rev_jnt_transform = tensor_axis_and_angle_to_matrix(axes_expanded, th)
    pris_jnt_transform = tensor_axis_and_d_to_pris_matrix(axes_expanded, th)

    # Build the identity once, directly with th's dtype/device, instead of
    # allocating on the default device and converting it for every frame.
    identity = torch.eye(4, dtype=th.dtype, device=th.device)

    frame_transforms = {}
    for frame_idx in frame_indices:
        # int() accepts both plain ints and single-element tensors.
        frame_key = int(frame_idx)
        # repeat() copies, so every frame starts from its own fresh tensor.
        frame_transform = identity.unsqueeze(0).repeat(b, 1, 1)

        # Iterate down the list of parents and compose the transform.
        for chain_idx in self.parents_indices[frame_key]:
            # Convert to a Python int exactly once per chain element; each
            # tensor -> Python conversion forces a host synchronisation.
            chain_key = int(chain_idx)

            if chain_key in frame_transforms:
                # Already solved in this call: restart from the cached result.
                frame_transform = frame_transforms[chain_key]
                continue

            link_offset_i = self.link_offsets[chain_key]
            if link_offset_i is not None:
                frame_transform = frame_transform @ link_offset_i

            joint_offset_i = self.joint_offsets[chain_key]
            if joint_offset_i is not None:
                frame_transform = frame_transform @ joint_offset_i

            jnt_idx = self.joint_indices[chain_key]
            jnt_type = int(self.joint_type_indices[chain_key])
            # Type 0 contributes no joint transform; type 1 selects the
            # axis-angle transform and type 2 the "pris" transform.
            if jnt_type == 1:
                frame_transform = frame_transform @ rev_jnt_transform[:, jnt_idx]
            elif jnt_type == 2:
                frame_transform = frame_transform @ pris_jnt_transform[:, jnt_idx]

        frame_transforms[frame_key] = frame_transform

    frame_names_and_transform3ds = {
        self.idx_to_frame[solved_idx]: tf.Transform3d(matrix=transform)
        for solved_idx, transform in frame_transforms.items()
    }

    return frame_names_and_transform3ds
372+
320373
def ensure_tensor(self, th):
321374
"""
322375
Converts a number of possible types into a tensor. The order of the tensor is determined by the order

tests/gen_fk_perf.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,30 @@ def main():
2424
number = 100
2525

2626
# iterate over all combinations and store in a pandas dataframe
27-
headers = ['chain', 'device', 'dtype', 'batch_size', 'time']
27+
headers = ['method', 'chain', 'device', 'dtype', 'batch_size', 'time']
2828
data = []
2929

30+
def _fk_cpp(th):
31+
return chain.forward_kinematics(th)
32+
33+
@torch.compile(backend='eager')
34+
def _fk_torch_compile(th):
35+
return chain.forward_kinematics_py(th)
36+
37+
method_names = ['fk_cpp', 'fk_torch_compile']
38+
methods = [_fk_cpp, _fk_torch_compile]
39+
3040
for chain, name in zip(chains, names):
3141
for device in devices:
3242
for dtype in dtypes:
3343
for batch_size in batch_sizes:
34-
chain = chain.to(dtype=dtype, device=device)
35-
th = torch.zeros(batch_size, chain.n_joints).to(dtype=dtype, device=device)
44+
for method_name, method in zip(method_names, methods):
45+
chain = chain.to(dtype=dtype, device=device)
46+
th = torch.zeros(batch_size, chain.n_joints).to(dtype=dtype, device=device)
3647

37-
dt = timeit.timeit(lambda: chain.forward_kinematics(th), number=number)
38-
data.append([name, device, dtype, batch_size, dt / number])
39-
print(f"{name=} {device=} {dtype=} {batch_size=} {dt / number:.4f}")
48+
dt = timeit.timeit(lambda: method(th), number=number)
49+
data.append([name, device, dtype, batch_size, dt / number])
50+
print(f"{method_name} {name=} {device=} {dtype=} {batch_size=} {dt / number:.4f}")
4051

4152
# pickle the data for visualization in jupyter notebook
4253
import pickle

tests/test_kinematics.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,34 @@ def test_mjcf_slide_joint_parsing():
255255

256256

257257
def test_fk_val():
258+
dtype = torch.float64
259+
d = "cuda" if torch.cuda.is_available() else "cpu"
260+
258261
chain = pk.build_chain_from_mjcf(open(os.path.join(TEST_DIR, "val.xml")).read())
259-
chain = chain.to(dtype=torch.float64)
260-
ret = chain.forward_kinematics(torch.zeros([1000, chain.n_joints], dtype=torch.float64))
262+
chain = chain.to(dtype=torch.float64, device=d)
263+
264+
th = torch.rand(1000, chain.n_joints, dtype=dtype, device=d)
265+
266+
def _fk_no_compile():
267+
return chain.forward_kinematics_py(th)
268+
269+
@torch.compile(backend='inductor')
270+
def _fk_compile():
271+
return chain.forward_kinematics_py(th)
272+
273+
from timeit import timeit
274+
275+
# warmup
276+
_fk_no_compile()
277+
_fk_compile()
278+
279+
number = 10
280+
ms_no_compile = timeit(_fk_no_compile, number=number) / number * 1000
281+
print(f"elapsed {ms_no_compile:.1f}ms for no compile")
282+
ms_compile = timeit(_fk_compile, number=number) / number * 1000
283+
print(f"elapsed {ms_compile:.1f}ms for compile")
284+
285+
ret = chain.forward_kinematics_py(th)
261286
tg = ret['drive45']
262287
pos, rot = quat_pos_from_transform3d(tg)
263288
torch.set_printoptions(precision=6, sci_mode=False)

0 commit comments

Comments
 (0)