1
1
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
2
3
3
import functools
4
+ import math
4
5
from typing import Optional
5
6
from warnings import warn
6
7
8
+ import numpy
7
9
import torch
8
10
import torch .nn .functional as F
9
11
@@ -450,7 +452,7 @@ def quaternion_apply(quaternion, point):
450
452
return out [..., 1 :]
451
453
452
454
453
- def tensor_axis_and_d_to_pris_matrix (axis , d ):
455
+ def axis_and_d_to_pris_matrix (axis , d ):
454
456
"""
455
457
Creates a 4x4 matrix that represents a translation along an axis of a distance d
456
458
Works with any number of batch dimensions.
@@ -470,7 +472,7 @@ def tensor_axis_and_d_to_pris_matrix(axis, d):
470
472
return mat44
471
473
472
474
473
- def tensor_axis_and_angle_to_matrix (axis , theta ):
475
+ def axis_and_angle_to_matrix (axis , theta ):
474
476
"""
475
477
Creates a 4x4 matrix that represents a rotation around an axis by an angle theta.
476
478
Works with any number of batch dimensions.
@@ -505,27 +507,6 @@ def tensor_axis_and_angle_to_matrix(axis, theta):
505
507
return mat44
506
508
507
509
508
- def axis_and_angle_to_matrix (axis , theta ):
509
- # based on https://ai.stackexchange.com/questions/14041/, and checked against wikipedia
510
- c = torch .cos (theta ) # NOTE: cos is not that precise for float32, you may want to use float64
511
- one_minus_c = 1 - c
512
- s = torch .sin (theta )
513
- kx , ky , kz = torch .unbind (axis , - 1 )
514
- r00 = c + kx * kx * one_minus_c
515
- r01 = kx * ky * one_minus_c - kz * s
516
- r02 = kx * kz * one_minus_c + ky * s
517
- r10 = ky * kx * one_minus_c + kz * s
518
- r11 = c + ky * ky * one_minus_c
519
- r12 = ky * kz * one_minus_c - kx * s
520
- r20 = kz * kx * one_minus_c - ky * s
521
- r21 = kz * ky * one_minus_c + kx * s
522
- r22 = c + kz * kz * one_minus_c
523
- rot = torch .stack ([torch .cat ([r00 , r01 , r02 ], - 1 ),
524
- torch .cat ([r10 , r11 , r12 ], - 1 ),
525
- torch .cat ([r20 , r21 , r22 ], - 1 )], - 2 )
526
- return rot
527
-
528
-
529
510
def axis_angle_to_matrix (axis_angle ):
530
511
"""
531
512
Convert rotations given as axis/angle to rotation matrices.
@@ -541,7 +522,7 @@ def axis_angle_to_matrix(axis_angle):
541
522
Returns:
542
523
Rotation matrices as tensor of shape (..., 3, 3).
543
524
"""
544
- warn ('This is deprecated because it is slow. Use axis_and_angle_to_matrix or zpk_cpp.axis_and_angle_to_matrix ' ,
525
+ warn ('This is deprecated because it is slow. Use axis_and_angle_to_matrix' ,
545
526
DeprecationWarning , stacklevel = 2 )
546
527
return quaternion_to_matrix (axis_angle_to_quaternion (axis_angle ))
547
528
@@ -682,3 +663,96 @@ def pos_rot_to_matrix(pos, rot):
682
663
m [..., :3 , 3 ] = pos
683
664
m [..., :3 , :3 ] = rot
684
665
return m
666
+
667
+
668
+ # axis sequences for Euler angles
669
+ _NEXT_AXIS = [1 , 2 , 0 , 1 ]
670
+
671
+ # map axes strings to/from tuples of inner axis, parity, repetition, frame
672
+ _AXES2TUPLE = {
673
+ 'sxyz' : (0 , 0 , 0 , 0 ),
674
+ 'sxyx' : (0 , 0 , 1 , 0 ),
675
+ 'sxzy' : (0 , 1 , 0 , 0 ),
676
+ 'sxzx' : (0 , 1 , 1 , 0 ),
677
+ 'syzx' : (1 , 0 , 0 , 0 ),
678
+ 'syzy' : (1 , 0 , 1 , 0 ),
679
+ 'syxz' : (1 , 1 , 0 , 0 ),
680
+ 'syxy' : (1 , 1 , 1 , 0 ),
681
+ 'szxy' : (2 , 0 , 0 , 0 ),
682
+ 'szxz' : (2 , 0 , 1 , 0 ),
683
+ 'szyx' : (2 , 1 , 0 , 0 ),
684
+ 'szyz' : (2 , 1 , 1 , 0 ),
685
+ 'rzyx' : (0 , 0 , 0 , 1 ),
686
+ 'rxyx' : (0 , 0 , 1 , 1 ),
687
+ 'ryzx' : (0 , 1 , 0 , 1 ),
688
+ 'rxzx' : (0 , 1 , 1 , 1 ),
689
+ 'rxzy' : (1 , 0 , 0 , 1 ),
690
+ 'ryzy' : (1 , 0 , 1 , 1 ),
691
+ 'rzxy' : (1 , 1 , 0 , 1 ),
692
+ 'ryxy' : (1 , 1 , 1 , 1 ),
693
+ 'ryxz' : (2 , 0 , 0 , 1 ),
694
+ 'rzxz' : (2 , 0 , 1 , 1 ),
695
+ 'rxyz' : (2 , 1 , 0 , 1 ),
696
+ 'rzyz' : (2 , 1 , 1 , 1 ),
697
+ }
698
+
699
+ _TUPLE2AXES = {v : k for k , v in _AXES2TUPLE .items ()}
700
+
701
+
702
+ def quaternion_from_euler (ai , aj , ak , axes = 'sxyz' ):
703
+ """
704
+ Return quaternion from Euler angles and axis sequence.
705
+ Taken from https://github.com/cgohlke/transformations/blob/master/transformations/transformations.py#L1238
706
+
707
+ ai, aj, ak : Euler's roll, pitch and yaw angles
708
+ axes : One of 24 axis sequences as string or encoded tuple
709
+
710
+ >>> q = quaternion_from_euler(1, 2, 3, 'ryxz')
711
+ >>> numpy.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435])
712
+ True
713
+
714
+ """
715
+ try :
716
+ firstaxis , parity , repetition , frame = _AXES2TUPLE [axes .lower ()]
717
+ except (AttributeError , KeyError ):
718
+ _TUPLE2AXES [axes ] # noqa: validation
719
+ firstaxis , parity , repetition , frame = axes
720
+
721
+ i = firstaxis + 1
722
+ j = _NEXT_AXIS [i + parity - 1 ] + 1
723
+ k = _NEXT_AXIS [i - parity ] + 1
724
+
725
+ if frame :
726
+ ai , ak = ak , ai
727
+ if parity :
728
+ aj = - aj
729
+
730
+ ai /= 2.0
731
+ aj /= 2.0
732
+ ak /= 2.0
733
+ ci = math .cos (ai )
734
+ si = math .sin (ai )
735
+ cj = math .cos (aj )
736
+ sj = math .sin (aj )
737
+ ck = math .cos (ak )
738
+ sk = math .sin (ak )
739
+ cc = ci * ck
740
+ cs = ci * sk
741
+ sc = si * ck
742
+ ss = si * sk
743
+
744
+ q = numpy .empty ((4 ,))
745
+ if repetition :
746
+ q [0 ] = cj * (cc - ss )
747
+ q [i ] = cj * (cs + sc )
748
+ q [j ] = sj * (cc + ss )
749
+ q [k ] = sj * (cs - sc )
750
+ else :
751
+ q [0 ] = cj * cc + sj * ss
752
+ q [i ] = cj * sc - sj * cs
753
+ q [j ] = cj * ss + sj * cc
754
+ q [k ] = cj * cs - sj * sc
755
+ if parity :
756
+ q [j ] *= - 1.0
757
+
758
+ return q
0 commit comments