Skip to content

Commit e2909b6

Browse files
author
Peter
committed
adjust serial vs parallel test
1 parent 0859d54 commit e2909b6

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

tests/test_kinematics.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from timeit import timeit
23
import os
34

45
import torch
@@ -127,21 +128,24 @@ def test_urdf_serial():
127128
th_batch = torch.rand(N, len(chain.get_joint_parameter_names()), dtype=dtype, device=d)
128129
chain = chain.to(dtype=dtype, device=d)
129130

130-
import time
131-
start = time.time()
132-
tg_batch = chain.forward_kinematics(th_batch)
133-
m = tg_batch.get_matrix()
134-
elapsed = time.time() - start
135-
print("elapsed {}s for N={} when parallel".format(elapsed, N))
131+
number = 10
136132

137-
start = time.time()
138-
elapsed = 0
139-
for i in range(N):
140-
tg = chain.forward_kinematics(th_batch[i])
141-
elapsed += time.time() - start
142-
start = time.time()
143-
assert torch.allclose(tg.get_matrix().view(4, 4), m[i])
144-
print("elapsed {}s for N={} when serial".format(elapsed, N))
133+
def _fk_parallel():
134+
tg_batch = chain.forward_kinematics(th_batch)
135+
m = tg_batch.get_matrix()
136+
137+
dt_parallel = timeit(_fk_parallel, number=number) / number
138+
print("elapsed {}s for N={} when parallel".format(dt_parallel, N))
139+
140+
def _fk_serial():
141+
for i in range(N):
142+
tg = chain.forward_kinematics(th_batch[i])
143+
m = tg.get_matrix()
144+
145+
dt_serial = timeit(_fk_serial, number=number) / number
146+
print("elapsed {}s for N={} when serial".format(dt_serial, N))
147+
148+
# assert torch.allclose(tg.get_matrix().view(4, 4), m[i])
145149

146150

147151
# test robot with prismatic and fixed joints

0 commit comments

Comments
 (0)