Skip to content

Commit 9075957

Browse files
committed
it turns out that zeroing out the coordinates and weights within WeightedRigidAlign is enough; make sure alphafold3 passes in the atom mask
1 parent 72e10b8 commit 9075957

File tree

2 files changed

+85
-38
lines changed

2 files changed

+85
-38
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,7 +1847,8 @@ def forward(
18471847
atom_pos_aligned_ground_truth = self.weighted_rigid_align(
18481848
atom_pos_ground_truth,
18491849
denoised_atom_pos,
1850-
align_weights
1850+
align_weights,
1851+
mask = atom_mask
18511852
)
18521853

18531854
# main diffusion mse loss
@@ -1932,10 +1933,10 @@ def forward(
19321933
coords_mask: Bool['b n'] | None = None,
19331934
) -> Float['']:
19341935
"""
1935-
pred_coords: predicted coordinates (b, n, 3)
1936-
true_coords: true coordinates (b, n, 3)
1937-
is_dna: boolean tensor indicating DNA atoms (b, n)
1938-
is_rna: boolean tensor indicating RNA atoms (b, n)
1936+
pred_coords: predicted coordinates
1937+
true_coords: true coordinates
1938+
is_dna: boolean tensor indicating DNA atoms
1939+
is_rna: boolean tensor indicating RNA atoms
19391940
"""
19401941
# Compute distances between all pairs of atoms
19411942
pred_dists = torch.cdist(pred_coords, pred_coords)
@@ -1954,15 +1955,16 @@ def forward(
19541955

19551956
# Restrict to bespoke inclusion radius
19561957
is_nucleotide = is_dna | is_rna
1957-
is_nucleotide_pair = is_nucleotide.unsqueeze(-1) & is_nucleotide.unsqueeze(-2)
1958+
is_nucleotide_pair = einx.logical_and('b i, b j -> b i j', is_nucleotide, is_nucleotide)
1959+
19581960
inclusion_radius = torch.where(
19591961
is_nucleotide_pair,
19601962
true_dists < self.nucleic_acid_cutoff,
19611963
true_dists < self.other_cutoff
19621964
)
19631965

19641966
# Compute mean, avoiding self term
1965-
mask = torch.logical_and(inclusion_radius, torch.logical_not(torch.eye(pred_coords.shape[1], dtype=torch.bool, device=pred_coords.device)))
1967+
mask = inclusion_radius & ~torch.eye(pred_coords.shape[1], dtype=torch.bool, device=pred_coords.device)
19661968

19671969
# Take into account variable numbers of atoms per sample in the batch
19681970
if exists(coords_mask):
@@ -1974,59 +1976,68 @@ def forward(
19741976
lddt_count = mask.sum(dim=(-1, -2))
19751977
lddt = lddt_sum / lddt_count.clamp(min=1)
19761978

1977-
return 1 - lddt.mean()
1979+
return 1. - lddt.mean()
19781980

19791981
class WeightedRigidAlign(Module):
19801982
""" Algorithm 28 """
1981-
def __init__(self):
1982-
super().__init__()
19831983

19841984
@typecheck
19851985
def forward(
19861986
self,
1987-
pred_coords: Float['b n 3'],
1988-
true_coords: Float['b n 3'],
1989-
weights: Float['b n']
1987+
pred_coords: Float['b n 3'], # predicted coordinates
1988+
true_coords: Float['b n 3'], # true coordinates
1989+
weights: Float['b n'], # weights for each atom
1990+
mask: Bool['b n'] | None = None # mask for variable lengths
19901991
) -> Float['b n 3']:
1991-
"""
1992-
pred_coords: predicted coordinates (b, n, 3)
1993-
true_coords: true coordinates (b, n, 3)
1994-
weights: weights for each atom (b, n)
1995-
"""
1992+
1993+
if exists(mask):
1994+
# zero out all predicted and true coordinates where not an atom
1995+
pred_coords = einx.where('b n, b n c, -> b n c', mask, pred_coords, 0.)
1996+
true_coords = einx.where('b n, b n c, -> b n c', mask, true_coords, 0.)
1997+
weights = einx.where('b n, b n, -> b n', mask, weights, 0.)
1998+
1999+
# Take care of weights broadcasting for coordinate dimension
2000+
weights = rearrange(weights, 'b n -> b n 1')
19962001

19972002
# Compute weighted centroids
1998-
pred_centroid = (pred_coords * weights.unsqueeze(-1)).sum(dim=1) / weights.sum(dim=1, keepdim=True)
1999-
true_centroid = (true_coords * weights.unsqueeze(-1)).sum(dim=1) / weights.sum(dim=1, keepdim=True)
2003+
pred_centroid = (pred_coords * weights).sum(dim=1, keepdim=True) / weights.sum(dim=1, keepdim=True)
2004+
true_centroid = (true_coords * weights).sum(dim=1, keepdim=True) / weights.sum(dim=1, keepdim=True)
20002005

20012006
# Center the coordinates
2002-
pred_coords_centered = pred_coords - pred_centroid.unsqueeze(1)
2003-
true_coords_centered = true_coords - true_centroid.unsqueeze(1)
2007+
pred_coords_centered = pred_coords - pred_centroid
2008+
true_coords_centered = true_coords - true_centroid
20042009

20052010
# Compute the weighted covariance matrix
2006-
cov_matrix = torch.einsum('bni,bnj->bij', true_coords_centered * weights.unsqueeze(-1), pred_coords_centered)
2011+
weighted_true_coords_center = true_coords_centered * weights
2012+
cov_matrix = einsum(weighted_true_coords_center, pred_coords_centered, 'b n i, b n j -> b i j')
20072013

20082014
# Compute the SVD of the covariance matrix
20092015
U, _, V = torch.svd(cov_matrix)
20102016

20112017
# Compute the rotation matrix
2012-
rot_matrix = torch.einsum('bij,bjk->bik', U, V)
2018+
rot_matrix = einsum(U, V, 'b i j, b j k -> b i k')
20132019

20142020
# Ensure proper rotation matrix with determinant 1
20152021
det = torch.det(rot_matrix)
20162022
det_mask = det < 0
20172023
V_fixed = V.clone()
20182024
V_fixed[det_mask, :, -1] *= -1
2019-
rot_matrix[det_mask] = torch.einsum('bij,bjk->bik', U[det_mask], V_fixed[det_mask])
2025+
2026+
rot_matrix[det_mask] = einsum(U[det_mask], V_fixed[det_mask], 'b i j, b j k -> b i k')
20202027

20212028
# Apply the rotation and translation
2022-
aligned_coords = torch.einsum('bni,bij->bnj', pred_coords_centered, rot_matrix) + true_centroid.unsqueeze(1)
2029+
aligned_coords = einsum(pred_coords_centered, rot_matrix, 'b n i, b i j -> b n j') + true_centroid
2030+
aligned_coords.detach_()
20232031

2024-
return aligned_coords.detach()
2032+
return aligned_coords
20252033

20262034
class ExpressCoordinatesInFrame(Module):
20272035
""" Algorithm 29 """
20282036

2029-
def __init__(self, eps = 1e-8):
2037+
def __init__(
2038+
self,
2039+
eps = 1e-8
2040+
):
20302041
super().__init__()
20312042
self.eps = eps
20322043

@@ -2037,8 +2048,8 @@ def forward(
20372048
frame: Float['b m 3 3'] | Float['b 3 3'] | Float['3 3']
20382049
) -> Float['b m 3']:
20392050
"""
2040-
coords: coordinates to be expressed in the given frame (b, 3)
2041-
frame: frame defined by three points (b, 3, 3)
2051+
coords: coordinates to be expressed in the given frame
2052+
frame: frame defined by three points
20422053
"""
20432054

20442055
if frame.ndim == 2:
@@ -2067,8 +2078,12 @@ def forward(
20672078

20682079
class ComputeAlignmentError(Module):
20692080
""" Algorithm 30 """
2081+
20702082
@typecheck
2071-
def __init__(self, eps: float = 1e-8):
2083+
def __init__(
2084+
self,
2085+
eps: float = 1e-8
2086+
):
20722087
super().__init__()
20732088
self.eps = eps
20742089
self.express_coordinates_in_frame = ExpressCoordinatesInFrame()
@@ -2082,10 +2097,10 @@ def forward(
20822097
true_frames: Float['b n 3 3']
20832098
) -> Float['b n']:
20842099
"""
2085-
pred_coords: predicted coordinates (b, n, 3)
2086-
true_coords: true coordinates (b, n, 3)
2087-
pred_frames: predicted frames (b, n, 3, 3)
2088-
true_frames: true frames (b, n, 3, 3)
2100+
pred_coords: predicted coordinates
2101+
true_coords: true coordinates
2102+
pred_frames: predicted frames
2103+
true_frames: true frames
20892104
"""
20902105
# Express predicted coordinates in predicted frames
20912106
pred_coords_transformed = self.express_coordinates_in_frame(pred_coords, pred_frames)
@@ -2102,6 +2117,7 @@ def forward(
21022117

21032118
class CentreRandomAugmentation(Module):
21042119
""" Algorithm 19 """
2120+
21052121
@typecheck
21062122
def __init__(self, trans_scale: float = 1.0):
21072123
super().__init__()
@@ -2110,7 +2126,7 @@ def __init__(self, trans_scale: float = 1.0):
21102126
@typecheck
21112127
def forward(self, coords: Float['b n 3']) -> Float['b n 3']:
21122128
"""
2113-
coords: coordinates to be augmented (b, n, 3)
2129+
coords: coordinates to be augmented
21142130
"""
21152131
# Center the coordinates
21162132
centered_coords = coords - coords.mean(dim=1, keepdim=True)

tests/test_af3.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
os.environ['TYPECHECK'] = 'True'
33

44
import torch
5+
from torch.nn.utils.rnn import pad_sequence
6+
57
import pytest
68

79
from alphafold3_pytorch import (
@@ -43,8 +45,6 @@ def test_calc_smooth_lddt_loss():
4345

4446
assert torch.all(loss <= 1) and torch.all(loss >= 0)
4547

46-
# ToDo tests
47-
4848
def test_smooth_lddt_loss():
4949
pred_coords = torch.randn(2, 100, 3)
5050
true_coords = torch.randn(2, 100, 3)
@@ -66,6 +66,37 @@ def test_weighted_rigid_align():
6666

6767
assert aligned_coords.shape == pred_coords.shape
6868

69+
def test_weighted_rigid_align_with_mask():
70+
pred_coords = torch.randn(2, 100, 3)
71+
true_coords = torch.randn(2, 100, 3)
72+
weights = torch.rand(2, 100)
73+
mask = torch.randint(0, 2, (2, 100)).bool()
74+
75+
align_fn = WeightedRigidAlign()
76+
77+
# with mask
78+
79+
aligned_coords = align_fn(pred_coords, true_coords, weights, mask = mask)
80+
81+
# do it one sample at a time without mask
82+
83+
all_aligned_coords = []
84+
85+
for one_mask, one_pred_coords, one_true_coords, one_weight in zip(mask, pred_coords, true_coords, weights):
86+
one_aligned_coords = align_fn(
87+
one_pred_coords[one_mask][None, ...],
88+
one_true_coords[one_mask][None, ...],
89+
one_weight[one_mask][None, ...]
90+
)
91+
92+
all_aligned_coords.append(one_aligned_coords.squeeze(0))
93+
94+
aligned_coords_without_mask = torch.cat(all_aligned_coords, dim = 0)
95+
96+
# both ways should come out with about the same results
97+
98+
assert torch.allclose(aligned_coords[mask], aligned_coords_without_mask, atol=1e-5)
99+
69100
def test_express_coordinates_in_frame():
70101
batch_size = 2
71102
num_coords = 100

0 commit comments

Comments
 (0)