Skip to content

Commit 6c775d1

Browse files
authored
Ensure ExpressCoordinatesInFrame perfectly matches Algorithm 29 of the AF3 supplement (#25)
* Ensure `ExpressCoordinatesInFrame` perfectly matches Algorithm 29 of the AF3 supplement * Update test_af3.py
1 parent 16b78b0 commit 6c775d1

File tree

2 files changed

+24
-19
lines changed

2 files changed

+24
-19
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2293,22 +2293,26 @@ def forward(
22932293
elif frame.ndim == 3:
22942294
frame = rearrange(frame, 'b fr fc -> b 1 fr fc')
22952295

2296-
# Extract frame points
2297-
a, b, c = frame.unbind(dim = -1)
2298-
2299-
# Compute unit vectors of the frame
2300-
e1 = F.normalize(a - b, dim = -1, eps = self.eps)
2301-
e2 = F.normalize(c - b, dim = -1, eps = self.eps)
2302-
e3 = torch.cross(e1, e2, dim = -1)
2303-
2304-
# Express coordinates in the frame basis
2305-
v = coords - b
2306-
2307-
transformed_coords = torch.stack([
2308-
einsum(v, e1, '... i, ... i -> ...'),
2309-
einsum(v, e2, '... i, ... i -> ...'),
2310-
einsum(v, e3, '... i, ... i -> ...')
2311-
], dim = -1)
2296+
# Extract frame atoms
2297+
a, b, c = frame.unbind(dim=-1)
2298+
w1 = F.normalize(a - b, dim=-1, eps=self.eps)
2299+
w2 = F.normalize(c - b, dim=-1, eps=self.eps)
2300+
2301+
# Build orthonormal basis
2302+
e1 = F.normalize(w1 + w2, dim=-1, eps=self.eps)
2303+
e2 = F.normalize(w2 - w1, dim=-1, eps=self.eps)
2304+
e3 = torch.cross(e1, e2, dim=-1)
2305+
2306+
# Project onto frame basis
2307+
d = coords - b
2308+
transformed_coords = torch.stack(
2309+
[
2310+
einsum(d, e1, '... i, ... i -> ...'),
2311+
einsum(d, e2, '... i, ... i -> ...'),
2312+
einsum(d, e3, '... i, ... i -> ...'),
2313+
],
2314+
dim=-1,
2315+
)
23122316

23132317
return transformed_coords
23142318

tests/test_af3.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,15 @@ def test_express_coordinates_in_frame():
127127

128128
def test_compute_alignment_error():
129129
pred_coords = torch.randn(2, 100, 3)
130-
true_coords = torch.randn(2, 100, 3)
131130
pred_frames = torch.randn(2, 100, 3, 3)
132-
true_frames = torch.randn(2, 100, 3, 3)
131+
132+
# `pred_coords` should match itself in frame basis
133133

134134
error_fn = ComputeAlignmentError()
135-
alignment_errors = error_fn(pred_coords, true_coords, pred_frames, true_frames)
135+
alignment_errors = error_fn(pred_coords, pred_coords, pred_frames, pred_frames)
136136

137137
assert alignment_errors.shape == (2, 100)
138+
assert (alignment_errors.mean(-1) < 1e-3).all()
138139

139140
def test_centre_random_augmentation():
140141
coords = torch.randn(2, 100, 3)

0 commit comments

Comments
 (0)