Skip to content

Commit 1951e74

Browse files
author
Peter
committed
fix timing test
1 parent b21e7c8 commit 1951e74

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/pytorch_kinematics/chain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,8 @@ def get_link_names(self):
279279

280280
@lru_cache
281281
def get_frame_indices(self, *frame_names):
282-
return torch.tensor([self.frame_to_idx[n] for n in frame_names], dtype=torch.long,
283-
device=self.device)
282+
return torch.tensor([self.frame_to_idx[n] for n in frame_names], dtype=torch.long, device=self.device)
283+
284284
def forward_kinematics(self, th, frame_indices: Optional = None):
285285
if frame_indices is None:
286286
frame_indices = self.get_all_frame_indices()

tests/test_kinematics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def test_urdf_serial():
115115
print(chain)
116116
print(chain.get_joint_parameter_names())
117117
th = [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]
118+
118119
ret = chain.forward_kinematics(th, end_only=False)
119120
tg = ret['lbr_iiwa_link_7']
120121
pos, rot = quat_pos_from_transform3d(tg)
@@ -126,8 +127,14 @@ def test_urdf_serial():
126127
dtype = torch.float64
127128

128129
th_batch = torch.rand(N, len(chain.get_joint_parameter_names()), dtype=dtype, device=d)
130+
129131
chain = chain.to(dtype=dtype, device=d)
130132

133+
# NOTE: Warmstart since pytorch can be slow the first time you run it
134+
# this has to be done after you move it to the GPU. Otherwise the timing isn't representative.
135+
for _ in range(5):
136+
ret = chain.forward_kinematics(th)
137+
131138
number = 10
132139

133140
def _fk_parallel():

0 commit comments

Comments
 (0)