Skip to content

Commit 16b78b0

Browse files
authored
Correct application of rotations to coordinates in WeightedRigidAlign within alphafold3.py (#24)
* Correct application of rotations to coordinates in `WeightedRigidAlign` within `alphafold3.py` * Update `WeightedRigidAlign` unit test in `test_af3.py` * Fix order of axes in `alphafold3.py`
1 parent 7f7280f commit 16b78b0

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@
7373

7474
ADDITIONAL_RESIDUE_FEATS = 10
7575

76+
# threshold for checking that point cross-correlation
77+
# is full-rank in `WeightedRigidAlign`
78+
AMBIGUOUS_ROT_SINGULAR_THR = 1e-15
79+
7680
LinearNoBias = partial(Linear, bias = False)
7781

7882
# helper functions
@@ -2210,6 +2214,7 @@ def forward(
22102214
weights: Float['b n'], # weights for each atom
22112215
mask: Bool['b n'] | None = None # mask for variable lengths
22122216
) -> Float['b n 3']:
2217+
batch_size, num_points, dim = pred_coords.shape
22132218

22142219
if exists(mask):
22152220
# zero out all predicted and true coordinates where not an atom
@@ -2228,26 +2233,36 @@ def forward(
22282233
pred_coords_centered = pred_coords - pred_centroid
22292234
true_coords_centered = true_coords - true_centroid
22302235

2236+
if num_points < (dim + 1):
2237+
print(
2238+
"Warning: The size of one of the point clouds is <= dim+1. "
2239+
+ "`WeightedRigidAlign` cannot return a unique rotation."
2240+
)
2241+
22312242
# Compute the weighted covariance matrix
2232-
weighted_true_coords_center = true_coords_centered * weights
2233-
cov_matrix = einsum(weighted_true_coords_center, pred_coords_centered, 'b n i, b n j -> b i j')
2243+
cov_matrix = einsum(weights * true_coords_centered, pred_coords_centered, 'b n i, b n j -> b i j')
22342244

22352245
# Compute the SVD of the covariance matrix
2236-
U, _, V = torch.svd(cov_matrix)
2246+
U, S, V = torch.svd(cov_matrix)
2247+
2248+
# Catch ambiguous rotation by checking the magnitude of singular values
2249+
if (S.abs() <= AMBIGUOUS_ROT_SINGULAR_THR).any() and not (num_points < (dim + 1)):
2250+
print(
2251+
"Warning: Excessively low rank of "
2252+
+ "cross-correlation between aligned point clouds. "
2253+
+ "`WeightedRigidAlign` cannot return a unique rotation."
2254+
)
22372255

22382256
# Compute the rotation matrix
22392257
rot_matrix = einsum(U, V, 'b i j, b k j -> b i k')
22402258

22412259
# Ensure proper rotation matrix with determinant 1
2242-
det = torch.det(rot_matrix)
2243-
det_mask = det < 0
2244-
V_fixed = V.clone()
2245-
V_fixed[det_mask, :, -1] *= -1
2246-
2247-
rot_matrix[det_mask] = einsum(U[det_mask], V_fixed[det_mask], 'b i j, b k j -> b i k')
2260+
F = torch.eye(dim, dtype=cov_matrix.dtype, device=cov_matrix.device)[None].repeat(batch_size, 1, 1)
2261+
F[:, -1, -1] = torch.det(rot_matrix)
2262+
rot_matrix = einsum(U, F, V, "b i j, b j k, b l k -> b i l")
22482263

22492264
# Apply the rotation and translation
2250-
aligned_coords = einsum(pred_coords_centered, rot_matrix, 'b n i, b i j -> b n j') + true_centroid
2265+
aligned_coords = einsum(pred_coords_centered, rot_matrix, 'b n i, b j i -> b n j') + true_centroid
22512266
aligned_coords.detach_()
22522267

22532268
return aligned_coords
@@ -2367,7 +2382,7 @@ def forward(self, coords: Float['b n 3']) -> Float['b n 3']:
23672382
translation_vector = rearrange(translation_vector, 'b c -> b 1 c')
23682383

23692384
# Apply rotation and translation
2370-
augmented_coords = einsum(centered_coords, rotation_matrix, 'b n i, b i j -> b n j') + translation_vector
2385+
augmented_coords = einsum(centered_coords, rotation_matrix, 'b n i, b j i -> b n j') + translation_vector
23712386

23722387
return augmented_coords
23732388

tests/test_af3.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ def test_weighted_rigid_align():
6565
rmsd = torch.sqrt(((pred_coords - aligned_coords) ** 2).sum(dim=-1).mean(dim=-1))
6666
assert (rmsd < 1e-5).all()
6767

68+
random_augment_fn = CentreRandomAugmentation()
69+
aligned_coords = align_fn(random_augment_fn(pred_coords), pred_coords, weights)
70+
71+
# `pred_coords` should match a random augmentation of itself after alignment
72+
73+
rmsd = torch.sqrt(((pred_coords - aligned_coords) ** 2).sum(dim=-1).mean(dim=-1))
74+
assert (rmsd < 1e-5).all()
75+
6876
def test_weighted_rigid_align_with_mask():
6977
pred_coords = torch.randn(2, 100, 3)
7078
true_coords = torch.randn(2, 100, 3)

0 commit comments

Comments
 (0)