Skip to content

Commit 4385a64

Browse files
author
Peter
committed
fix tests
1 parent e2909b6 commit 4385a64

File tree

3 files changed

+29
-24
lines changed

3 files changed

+29
-24
lines changed

tests/test_menagerie.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_menagerie():
4444
print(f"\t {chain.get_frame_names()}")
4545
print(f"\t {chain.get_joint_parameter_names()}")
4646
th = np.zeros(len(chain.get_joint_parameter_names()))
47-
fk_dict = chain.forward_kinematics(th, end_only=True)
47+
fk_dict = chain.forward_kinematics(th)
4848

4949

5050
if __name__ == '__main__':

tests/test_rotation_conversions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ def test_axis_angle_to_matrix_perf():
1111
number = 100
1212
N = 1_000
1313

14-
axis_angle = torch.randn([N, 3], device='cuda', dtype=torch.float64)
15-
axis_1d = torch.tensor([1., 0, 0], device='cuda', dtype=torch.float64) # in the FK code this is NOT batched!
14+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
15+
axis_angle = torch.randn([N, 3], device=device, dtype=torch.float64)
16+
axis_1d = torch.tensor([1., 0, 0], device=device, dtype=torch.float64) # in the FK code this is NOT batched!
1617
theta = axis_angle.norm(dim=1, keepdim=True)
1718

1819
dt1 = timeit.timeit(lambda: axis_angle_to_matrix(axis_angle), number=number)

0 commit comments

Comments
 (0)