|
1 | 1 | import math
|
| 2 | +from timeit import timeit |
2 | 3 | import os
|
3 | 4 |
|
4 | 5 | import torch
|
@@ -127,21 +128,24 @@ def test_urdf_serial():
|
127 | 128 | th_batch = torch.rand(N, len(chain.get_joint_parameter_names()), dtype=dtype, device=d)
|
128 | 129 | chain = chain.to(dtype=dtype, device=d)
|
129 | 130 |
|
130 |
| - import time |
131 |
| - start = time.time() |
132 |
| - tg_batch = chain.forward_kinematics(th_batch) |
133 |
| - m = tg_batch.get_matrix() |
134 |
| - elapsed = time.time() - start |
135 |
| - print("elapsed {}s for N={} when parallel".format(elapsed, N)) |
| 131 | + number = 10 |
136 | 132 |
|
137 |
| - start = time.time() |
138 |
| - elapsed = 0 |
139 |
| - for i in range(N): |
140 |
| - tg = chain.forward_kinematics(th_batch[i]) |
141 |
| - elapsed += time.time() - start |
142 |
| - start = time.time() |
143 |
| - assert torch.allclose(tg.get_matrix().view(4, 4), m[i]) |
144 |
| - print("elapsed {}s for N={} when serial".format(elapsed, N)) |
| 133 | + def _fk_parallel(): |
| 134 | + tg_batch = chain.forward_kinematics(th_batch) |
| 135 | + m = tg_batch.get_matrix() |
| 136 | + |
| 137 | + dt_parallel = timeit(_fk_parallel, number=number) / number |
| 138 | + print("elapsed {}s for N={} when parallel".format(dt_parallel, N)) |
| 139 | + |
| 140 | + def _fk_serial(): |
| 141 | + for i in range(N): |
| 142 | + tg = chain.forward_kinematics(th_batch[i]) |
| 143 | + m = tg.get_matrix() |
| 144 | + |
| 145 | + dt_serial = timeit(_fk_serial, number=number) / number |
| 146 | + print("elapsed {}s for N={} when serial".format(dt_serial, N)) |
| 147 | + |
| 148 | + # assert torch.allclose(tg.get_matrix().view(4, 4), m[i]) |
145 | 149 |
|
146 | 150 |
|
147 | 151 | # test robot with prismatic and fixed joints
|
|
0 commit comments