@@ -809,13 +809,12 @@ def forward(
809809 self ,
810810 coords : Float ["b m 3" ], # type: ignore
811811 frame : Float ["b m 3 3" ] | Float ["b 3 3" ] | Float ["3 3" ], # type: ignore
812- pairwise : bool = False ,
813- ) -> Float ["b m 3" ] | Float ["b m m 3" ]: # type: ignore
812+ ) -> Float ["b m 3" ]: # type: ignore
814813 """Express coordinates in the given frame.
815814
816815 :param coords: Coordinates to be expressed in the given frame.
817816 :param frame: Frames defined by three points.
818- :return: The transformed coordinates or pairwise coordinates .
817+ :return: The transformed coordinates.
819818 """
820819
821820 if frame .ndim == 2 :
@@ -833,38 +832,19 @@ def forward(
833832 e2 = l2norm (w2 - w1 , eps = self .eps )
834833 e3 = torch .cross (e1 , e2 , dim = - 1 )
835834
836- if pairwise :
837- # Compute pairwise displacement vectors
838- pairwise_d = coords .unsqueeze (2 ) - coords .unsqueeze (1 )
835+ # Project onto frame basis
836+ d = coords - b
839837
840- # Project onto frame basis
841- pairwise_transformed_coords = torch .stack (
842- (
843- einsum (pairwise_d , e1 .unsqueeze (1 ), "... i, ... i -> ..." ),
844- einsum (pairwise_d , e2 .unsqueeze (1 ), "... i, ... i -> ..." ),
845- einsum (pairwise_d , e3 .unsqueeze (1 ), "... i, ... i -> ..." ),
846- ),
847- dim = - 1 ,
848- )
849-
850- # Normalize to get unit vectors
851- pairwise_transformed_coords = l2norm (pairwise_transformed_coords , eps = self .eps )
852- return pairwise_transformed_coords
853-
854- else :
855- # Project onto frame basis
856- d = coords - b
857-
858- transformed_coords = torch .stack (
859- (
860- einsum (d , e1 , "... i, ... i -> ..." ),
861- einsum (d , e2 , "... i, ... i -> ..." ),
862- einsum (d , e3 , "... i, ... i -> ..." ),
863- ),
864- dim = - 1 ,
865- )
838+ transformed_coords = torch .stack (
839+ (
840+ einsum (d , e1 , "... i, ... i -> ..." ),
841+ einsum (d , e2 , "... i, ... i -> ..." ),
842+ einsum (d , e3 , "... i, ... i -> ..." ),
843+ ),
844+ dim = - 1 ,
845+ )
866846
867- return transformed_coords
847+ return transformed_coords
868848
869849
870850class RigidFrom3Points (Module ):
@@ -906,3 +886,97 @@ def forward(
906886 t = unpack_one (t , "* c" )
907887
908888 return R , t
889+
890+
891+ class RigidFromReference3Points (Module ):
892+ """A modification of Algorithm 21 in Section 1.8.1 in AlphaFold 2 paper:
893+
894+ https://www.nature.com/articles/s41586-021-03819-2
895+
896+ Inpsired by the implementation in the OpenFold codebase:
897+ https://github.com/aqlaboratory/openfold/blob/6f63267114435f94ac0604b6d89e82ef45d94484/openfold/utils/feats.py#L143
898+ """
899+
900+ @typecheck
901+ def forward (
902+ self ,
903+ three_points : Tuple [Float ["... 3" ], Float ["... 3" ], Float ["... 3" ]] | Float ["3 ... 3" ], # type: ignore
904+ eps : float = 1e-20 ,
905+ ) -> Tuple [Float ["... 3 3" ], Float ["... 3" ]]: # type: ignore
906+ """Return a transformation object from reference coordinates.
907+
908+ NOTE: This method does not take care of symmetries. If you
909+ provide the atom positions in the non-standard way,
910+ e.g., the N atom of amino acid residues will end up
911+ not at [-0.527250, 1.359329, 0.0] but instead at
912+ [-0.527250, -1.359329, 0.0]. You need to take care
913+ of such cases in your code.
914+
915+ :param three_points: Three reference points to define the transformation.
916+ :param eps: A small value to avoid division by zero.
917+ :return: A transformation object. After applying the translation and
918+ rotation to the reference backbone, the coordinates will
919+ approximately equal to the input coordinates.
920+ """
921+ if isinstance (three_points , tuple ):
922+ three_points = torch .stack (three_points )
923+
924+ # allow for any number of leading dimensions
925+
926+ (x1 , x2 , x3 ), unpack_one = pack_one (three_points , "three * d" )
927+
928+ # main algorithm
929+
930+ t = - 1 * x2
931+ x1 = x1 + t
932+ x3 = x3 + t
933+
934+ x3_x , x3_y , x3_z = [x3 [..., i ] for i in range (3 )]
935+ norm = torch .sqrt (eps + x3_x ** 2 + x3_y ** 2 )
936+ sin_x3_1 = - x3_y / norm
937+ cos_x3_1 = x3_x / norm
938+
939+ x3_1_R = sin_x3_1 .new_zeros ((* sin_x3_1 .shape , 3 , 3 ))
940+ x3_1_R [..., 0 , 0 ] = cos_x3_1
941+ x3_1_R [..., 0 , 1 ] = - 1 * sin_x3_1
942+ x3_1_R [..., 1 , 0 ] = sin_x3_1
943+ x3_1_R [..., 1 , 1 ] = cos_x3_1
944+ x3_1_R [..., 2 , 2 ] = 1
945+
946+ norm = torch .sqrt (eps + x3_x ** 2 + x3_y ** 2 + x3_z ** 2 )
947+ sin_x3_2 = x3_z / norm
948+ cos_x3_2 = torch .sqrt (x3_x ** 2 + x3_y ** 2 ) / norm
949+
950+ x3_2_R = sin_x3_2 .new_zeros ((* sin_x3_2 .shape , 3 , 3 ))
951+ x3_2_R [..., 0 , 0 ] = cos_x3_2
952+ x3_2_R [..., 0 , 2 ] = sin_x3_2
953+ x3_2_R [..., 1 , 1 ] = 1
954+ x3_2_R [..., 2 , 0 ] = - 1 * sin_x3_2
955+ x3_2_R [..., 2 , 2 ] = cos_x3_2
956+
957+ x3_R = einsum (x3_2_R , x3_1_R , "n i j, n j k -> n i k" )
958+ x1 = einsum (x3_R , x1 , "n i j, n j -> n i" )
959+
960+ _ , x1_y , x1_z = [x1 [..., i ] for i in range (3 )]
961+ norm = torch .sqrt (eps + x1_y ** 2 + x1_z ** 2 )
962+ sin_x1 = - x1_z / norm
963+ cos_x1 = x1_y / norm
964+
965+ x1_R = sin_x3_2 .new_zeros ((* sin_x3_2 .shape , 3 , 3 ))
966+ x1_R [..., 0 , 0 ] = 1
967+ x1_R [..., 1 , 1 ] = cos_x1
968+ x1_R [..., 1 , 2 ] = - 1 * sin_x1
969+ x1_R [..., 2 , 1 ] = sin_x1
970+ x1_R [..., 2 , 2 ] = cos_x1
971+
972+ R = einsum (x1_R , x3_R , "n i j, n j k -> n i k" )
973+
974+ R = R .transpose (- 1 , - 2 )
975+ t = - 1 * t
976+
977+ # unpack
978+
979+ R = unpack_one (R , "* r1 r2" )
980+ t = unpack_one (t , "* c" )
981+
982+ return R , t
0 commit comments