Skip to content

Commit 4c42f1b

Browse files
authored
Transpose SVD matrix V in WeightedRigidAlign to get correct alignments (#22)
* Transpose SVD matrix `V` in `WeightedRigidAlign` to get correct alignment * Update test_af3.py
1 parent 932d71b commit 4c42f1b

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2198,15 +2198,15 @@ def forward(
21982198
U, _, V = torch.svd(cov_matrix)
21992199

22002200
# Compute the rotation matrix
2201-
rot_matrix = einsum(U, V, 'b i j, b j k -> b i k')
2201+
rot_matrix = einsum(U, V, 'b i j, b k j -> b i k')
22022202

22032203
# Ensure proper rotation matrix with determinant 1
22042204
det = torch.det(rot_matrix)
22052205
det_mask = det < 0
22062206
V_fixed = V.clone()
22072207
V_fixed[det_mask, :, -1] *= -1
22082208

2209-
rot_matrix[det_mask] = einsum(U[det_mask], V_fixed[det_mask], 'b i j, b j k -> b i k')
2209+
rot_matrix[det_mask] = einsum(U[det_mask], V_fixed[det_mask], 'b i j, b k j -> b i k')
22102210

22112211
# Apply the rotation and translation
22122212
aligned_coords = einsum(pred_coords_centered, rot_matrix, 'b n i, b i j -> b n j') + true_centroid

tests/test_af3.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,15 @@ def test_smooth_lddt_loss():
5555

5656
def test_weighted_rigid_align():
5757
pred_coords = torch.randn(2, 100, 3)
58-
true_coords = torch.randn(2, 100, 3)
5958
weights = torch.rand(2, 100)
6059

6160
align_fn = WeightedRigidAlign()
62-
aligned_coords = align_fn(pred_coords, true_coords, weights)
61+
aligned_coords = align_fn(pred_coords, pred_coords, weights)
62+
63+
# `pred_coords` should match itself without any change after alignment
6364

64-
assert aligned_coords.shape == pred_coords.shape
65+
rmsd = torch.sqrt(((pred_coords - aligned_coords) ** 2).sum(dim=-1).mean(dim=-1))
66+
assert (rmsd < 1e-5).all()
6567

6668
def test_weighted_rigid_align_with_mask():
6769
pred_coords = torch.randn(2, 100, 3)

0 commit comments

Comments
 (0)