Skip to content

Commit 8fbcfd1

Browse files
authored
Make template unit vector computations match those of OpenFold (#235)
* Update __init__.py * Update test_af3.py * Update model_utils.py * Update alphafold3.py * Update template_parsing.py * Update template_parsing.py * Update alphafold3.py * Update __init__.py * Update template_parsing.py * Update model_utils.py
1 parent 1cb7481 commit 8fbcfd1

File tree

5 files changed

+137
-43
lines changed

5 files changed

+137
-43
lines changed

alphafold3_pytorch/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,15 @@
7171
from alphafold3_pytorch.utils.model_utils import (
7272
ExpressCoordinatesInFrame,
7373
RigidFrom3Points,
74+
RigidFromReference3Points,
7475
)
7576

7677
__all__ = [
7778
Attention,
7879
Attend,
7980
RelativePositionEncoding,
81+
RigidFrom3Points,
82+
RigidFromReference3Points,
8083
SmoothLDDTLoss,
8184
WeightedRigidAlign,
8285
MultiChainPermutationAlignment,

alphafold3_pytorch/alphafold3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
from alphafold3_pytorch.utils.model_utils import (
7979
ExpressCoordinatesInFrame,
8080
RigidFrom3Points,
81+
RigidFromReference3Points,
8182
calculate_weighted_rigid_align_weights,
8283
package_available,
8384
)

alphafold3_pytorch/data/template_parsing.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from datetime import datetime
33
from loguru import logger
44
from beartype.typing import Any, Dict, List, Literal, Mapping, Tuple
5+
from einops import einsum
56

67
import numpy as np
78
import polars as pl
@@ -22,8 +23,7 @@
2223
)
2324
from alphafold3_pytorch.utils.data_utils import extract_mmcif_metadata_field
2425
from alphafold3_pytorch.utils.model_utils import (
25-
ExpressCoordinatesInFrame,
26-
RigidFrom3Points,
26+
RigidFromReference3Points,
2727
distance_to_dgram,
2828
get_frames_from_atom_pos,
2929
)
@@ -150,6 +150,7 @@ def _extract_template_features(
150150
num_distogram_bins: int = 39,
151151
distance_bins: List[float] = torch.linspace(3.25, 50.75, 39).float(),
152152
verbose: bool = False,
153+
eps: float = 1e-20,
153154
) -> Dict[str, Any]:
154155
"""Parse atom positions in the target structure and align with the query.
155156
@@ -173,6 +174,7 @@ def _extract_template_features(
173174
:param distance_bins: List of floats representing the bins for the distance
174175
histogram (i.e., distogram).
175176
:param verbose: Whether to log verbose output.
177+
:param eps: A small value to prevent division by zero.
176178
177179
:return: A dictionary containing the extra features derived from the template
178180
structure.
@@ -380,17 +382,23 @@ def _extract_template_features(
380382
template_three_atom_indices_for_frame.unsqueeze(-1).expand(-1, -1, 3),
381383
)
382384

383-
rigid_from_three_points = RigidFrom3Points()
384-
template_backbone_frames, _ = rigid_from_three_points(
385+
rigid_from_reference_3_points = RigidFromReferenceThreePoints()
386+
template_backbone_frames, template_backbone_points = rigid_from_reference_3_points(
385387
template_backbone_frame_atom_positions.unbind(-2)
386388
)
387389

388-
express_coordinates_in_frame = ExpressCoordinatesInFrame()
389-
template_unit_vector = express_coordinates_in_frame(
390-
template_token_center_atom_positions.unsqueeze(0),
391-
template_backbone_frames.unsqueeze(0),
392-
pairwise=True,
393-
).squeeze(0)
390+
inv_template_backbone_frames = template_backbone_frames.transpose(-1, -2)
391+
template_backbone_vec = einsum(
392+
inv_template_backbone_frames,
393+
template_backbone_points.unsqueeze(-2) - template_backbone_points.unsqueeze(-3),
394+
"n i j, m n j -> m n i",
395+
)
396+
template_inv_distance_scalar = torch.rsqrt(eps + torch.sum(template_backbone_vec**2, dim=-1))
397+
template_inv_distance_scalar = (
398+
template_inv_distance_scalar * template_backbone_frame_mask.unsqueeze(-1)
399+
)
400+
401+
template_unit_vector = template_backbone_vec * template_inv_distance_scalar.unsqueeze(-1)
394402

395403
return {
396404
"template_restype": template_restype.float(),

alphafold3_pytorch/utils/model_utils.py

Lines changed: 107 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -809,13 +809,12 @@ def forward(
809809
self,
810810
coords: Float["b m 3"], # type: ignore
811811
frame: Float["b m 3 3"] | Float["b 3 3"] | Float["3 3"], # type: ignore
812-
pairwise: bool = False,
813-
) -> Float["b m 3"] | Float["b m m 3"]: # type: ignore
812+
) -> Float["b m 3"]: # type: ignore
814813
"""Express coordinates in the given frame.
815814
816815
:param coords: Coordinates to be expressed in the given frame.
817816
:param frame: Frames defined by three points.
818-
:return: The transformed coordinates or pairwise coordinates.
817+
:return: The transformed coordinates.
819818
"""
820819

821820
if frame.ndim == 2:
@@ -833,38 +832,19 @@ def forward(
833832
e2 = l2norm(w2 - w1, eps=self.eps)
834833
e3 = torch.cross(e1, e2, dim=-1)
835834

836-
if pairwise:
837-
# Compute pairwise displacement vectors
838-
pairwise_d = coords.unsqueeze(2) - coords.unsqueeze(1)
835+
# Project onto frame basis
836+
d = coords - b
839837

840-
# Project onto frame basis
841-
pairwise_transformed_coords = torch.stack(
842-
(
843-
einsum(pairwise_d, e1.unsqueeze(1), "... i, ... i -> ..."),
844-
einsum(pairwise_d, e2.unsqueeze(1), "... i, ... i -> ..."),
845-
einsum(pairwise_d, e3.unsqueeze(1), "... i, ... i -> ..."),
846-
),
847-
dim=-1,
848-
)
849-
850-
# Normalize to get unit vectors
851-
pairwise_transformed_coords = l2norm(pairwise_transformed_coords, eps=self.eps)
852-
return pairwise_transformed_coords
853-
854-
else:
855-
# Project onto frame basis
856-
d = coords - b
857-
858-
transformed_coords = torch.stack(
859-
(
860-
einsum(d, e1, "... i, ... i -> ..."),
861-
einsum(d, e2, "... i, ... i -> ..."),
862-
einsum(d, e3, "... i, ... i -> ..."),
863-
),
864-
dim=-1,
865-
)
838+
transformed_coords = torch.stack(
839+
(
840+
einsum(d, e1, "... i, ... i -> ..."),
841+
einsum(d, e2, "... i, ... i -> ..."),
842+
einsum(d, e3, "... i, ... i -> ..."),
843+
),
844+
dim=-1,
845+
)
866846

867-
return transformed_coords
847+
return transformed_coords
868848

869849

870850
class RigidFrom3Points(Module):
@@ -906,3 +886,97 @@ def forward(
906886
t = unpack_one(t, "* c")
907887

908888
return R, t
889+
890+
891+
class RigidFromReference3Points(Module):
892+
"""A modification of Algorithm 21 in Section 1.8.1 in AlphaFold 2 paper:
893+
894+
https://www.nature.com/articles/s41586-021-03819-2
895+
896+
Inpsired by the implementation in the OpenFold codebase:
897+
https://github.com/aqlaboratory/openfold/blob/6f63267114435f94ac0604b6d89e82ef45d94484/openfold/utils/feats.py#L143
898+
"""
899+
900+
@typecheck
901+
def forward(
902+
self,
903+
three_points: Tuple[Float["... 3"], Float["... 3"], Float["... 3"]] | Float["3 ... 3"], # type: ignore
904+
eps: float = 1e-20,
905+
) -> Tuple[Float["... 3 3"], Float["... 3"]]: # type: ignore
906+
"""Return a transformation object from reference coordinates.
907+
908+
NOTE: This method does not take care of symmetries. If you
909+
provide the atom positions in the non-standard way,
910+
e.g., the N atom of amino acid residues will end up
911+
not at [-0.527250, 1.359329, 0.0] but instead at
912+
[-0.527250, -1.359329, 0.0]. You need to take care
913+
of such cases in your code.
914+
915+
:param three_points: Three reference points to define the transformation.
916+
:param eps: A small value to avoid division by zero.
917+
:return: A transformation object. After applying the translation and
918+
rotation to the reference backbone, the coordinates will
919+
approximately equal to the input coordinates.
920+
"""
921+
if isinstance(three_points, tuple):
922+
three_points = torch.stack(three_points)
923+
924+
# allow for any number of leading dimensions
925+
926+
(x1, x2, x3), unpack_one = pack_one(three_points, "three * d")
927+
928+
# main algorithm
929+
930+
t = -1 * x2
931+
x1 = x1 + t
932+
x3 = x3 + t
933+
934+
x3_x, x3_y, x3_z = [x3[..., i] for i in range(3)]
935+
norm = torch.sqrt(eps + x3_x**2 + x3_y**2)
936+
sin_x3_1 = -x3_y / norm
937+
cos_x3_1 = x3_x / norm
938+
939+
x3_1_R = sin_x3_1.new_zeros((*sin_x3_1.shape, 3, 3))
940+
x3_1_R[..., 0, 0] = cos_x3_1
941+
x3_1_R[..., 0, 1] = -1 * sin_x3_1
942+
x3_1_R[..., 1, 0] = sin_x3_1
943+
x3_1_R[..., 1, 1] = cos_x3_1
944+
x3_1_R[..., 2, 2] = 1
945+
946+
norm = torch.sqrt(eps + x3_x**2 + x3_y**2 + x3_z**2)
947+
sin_x3_2 = x3_z / norm
948+
cos_x3_2 = torch.sqrt(x3_x**2 + x3_y**2) / norm
949+
950+
x3_2_R = sin_x3_2.new_zeros((*sin_x3_2.shape, 3, 3))
951+
x3_2_R[..., 0, 0] = cos_x3_2
952+
x3_2_R[..., 0, 2] = sin_x3_2
953+
x3_2_R[..., 1, 1] = 1
954+
x3_2_R[..., 2, 0] = -1 * sin_x3_2
955+
x3_2_R[..., 2, 2] = cos_x3_2
956+
957+
x3_R = einsum(x3_2_R, x3_1_R, "n i j, n j k -> n i k")
958+
x1 = einsum(x3_R, x1, "n i j, n j -> n i")
959+
960+
_, x1_y, x1_z = [x1[..., i] for i in range(3)]
961+
norm = torch.sqrt(eps + x1_y**2 + x1_z**2)
962+
sin_x1 = -x1_z / norm
963+
cos_x1 = x1_y / norm
964+
965+
x1_R = sin_x3_2.new_zeros((*sin_x3_2.shape, 3, 3))
966+
x1_R[..., 0, 0] = 1
967+
x1_R[..., 1, 1] = cos_x1
968+
x1_R[..., 1, 2] = -1 * sin_x1
969+
x1_R[..., 2, 1] = sin_x1
970+
x1_R[..., 2, 2] = cos_x1
971+
972+
R = einsum(x1_R, x3_R, "n i j, n j k -> n i k")
973+
974+
R = R.transpose(-1, -2)
975+
t = -1 * t
976+
977+
# unpack
978+
979+
R = unpack_one(R, "* r1 r2")
980+
t = unpack_one(t, "* c")
981+
982+
return R, t

tests/test_af3.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
MultiChainPermutationAlignment,
2020
ExpressCoordinatesInFrame,
2121
RigidFrom3Points,
22+
RigidFromReference3Points,
2223
ComputeAlignmentError,
2324
CentreRandomAugmentation,
2425
PairformerStack,
@@ -245,6 +246,13 @@ def test_rigid_from_three_points():
245246
rotation, _ = rigid_from_3_points((points, points, points))
246247
assert rotation.shape == (7, 11, 23, 3, 3)
247248

249+
def test_rigid_from_reference_three_points():
250+
rigid_from_reference_3_points = RigidFromReference3Points()
251+
252+
points = torch.randn(7, 11, 23, 3)
253+
rotation, _ = rigid_from_reference_3_points((points, points, points))
254+
assert rotation.shape == (7, 11, 23, 3, 3)
255+
248256
def test_deriving_frames_for_ligands():
249257
points = torch.tensor([
250258
[1., 1., 1.],

0 commit comments

Comments
 (0)