Skip to content

Commit 8e8b799

Browse files
authored
#201 Move dgram calculation to model_utils, parametrize label returning (#208)
1 parent f3109e8 commit 8e8b799

File tree

3 files changed

+37
-53
lines changed

3 files changed

+37
-53
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@
7575
package_available,
7676
)
7777

78+
from alphafold3_pytorch.utils.model_utils import distance_to_dgram
79+
7880
from frame_averaging_pytorch import FrameAverage
7981

8082
from taylor_series_linear_attention import TaylorSeriesLinearAttn
@@ -455,28 +457,6 @@ def batch_repeat_interleave_pairwise(
455457
pairwise = batch_repeat_interleave(pairwise, molecule_atom_lens)
456458
return unpack_one(pairwise)
457459

458-
@typecheck
459-
def distance_to_bins(
460-
distance: Float['... dist'],
461-
bins: Float[' bins']
462-
) -> Int['... dist']:
463-
"""
464-
converting from distance to discrete bins, for distance_labels and pae_labels
465-
using the same logic as openfold
466-
"""
467-
468-
distance = distance ** 2
469-
470-
bins = F.pad(bins ** 2, (0, 1), value = float('inf'))
471-
low, high = bins[:-1], bins[1:]
472-
473-
one_hot = (
474-
einx.greater_equal('..., bin_low -> ... bin_low', distance, low) &
475-
einx.less('..., bin_high -> ... bin_high', distance, high)
476-
).long()
477-
478-
return one_hot.argmax(dim = -1)
479-
480460
# linear and outer sum
481461
# for single repr -> pairwise pattern throughout this architecture
482462

@@ -4501,7 +4481,9 @@ def forward(
45014481

45024482
intermolecule_dist = torch.cdist(pred_molecule_pos, pred_molecule_pos, p=2)
45034483

4504-
dist_bin_indices = distance_to_bins(intermolecule_dist, self.atompair_dist_bins)
4484+
dist_bin_indices = distance_to_dgram(
4485+
intermolecule_dist, self.atompair_dist_bins, return_labels=True
4486+
)
45054487
pairwise_repr = pairwise_repr + self.dist_bin_pairwise_embed(dist_bin_indices)
45064488

45074489
# pairformer stack
@@ -6734,7 +6716,9 @@ def forward(
67346716
distogram_mask = atom_mask
67356717

67366718
distogram_dist = torch.cdist(distogram_pos, distogram_pos, p=2)
6737-
distance_labels = distance_to_bins(distogram_dist, self.distance_bins)
6719+
distance_labels = distance_to_dgram(
6720+
distogram_dist, self.distance_bins, return_labels = True
6721+
)
67386722

67396723
# account for representative distogram atom missing from residue (-1 set on distogram_atom_indices field)
67406724

@@ -7014,9 +6998,9 @@ def forward(
70146998
mask=align_error_mask,
70156999
)
70167000

7017-
# calculate pae labels as alignment error binned to 64 (0 - 32A) (TODO: double-check correctness of `distance_to_bins`'s bin assignments)
7001+
# calculate pae labels as alignment error binned to 64 (0 - 32A)
70187002

7019-
pae_labels = distance_to_bins(align_error, self.pae_bins)
7003+
pae_labels = distance_to_dgram(align_error, self.pae_bins, return_labels = True)
70207004

70217005
# set ignore index for invalid molecules or frames
70227006

@@ -7058,7 +7042,7 @@ def forward(
70587042
# calculate pde labels as distance error binned to 64 (0 - 32A)
70597043

70607044
pde_dist = torch.abs(pde_pred_dist - pde_gt_dist)
7061-
pde_labels = distance_to_bins(pde_dist, self.pde_bins)
7045+
pde_labels = distance_to_dgram(pde_dist, self.pde_bins, return_labels = True)
70627046

70637047
# account for representative molecule atom missing from residue (-1 set on molecule_atom_indices field)
70647048

alphafold3_pytorch/data/template_parsing.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from alphafold3_pytorch.utils.model_utils import (
2525
ExpressCoordinatesInFrame,
2626
RigidFrom3Points,
27-
distance_to_bins,
27+
distance_to_dgram,
2828
get_frames_from_atom_pos,
2929
)
3030
from alphafold3_pytorch.tensor_typing import typecheck
@@ -367,17 +367,7 @@ def _extract_template_features(
367367
template_distogram_dist = torch.cdist(
368368
template_distogram_atom_positions, template_distogram_atom_positions, p=2
369369
)
370-
template_distogram_dist_binned = distance_to_bins(template_distogram_dist, distance_bins)
371-
372-
template_distogram_dist_binned[
373-
# NOTE: This assigns the last bin to distances greater than the maximum bin (e.g., > 50.75 Å).
374-
template_distogram_dist
375-
> distance_bins.max()
376-
] = (num_distogram_bins - 1)
377-
378-
template_distogram = F.one_hot(
379-
template_distogram_dist_binned, num_classes=num_distogram_bins
380-
).float()
370+
template_distogram = distance_to_dgram(template_distogram_dist, distance_bins)
381371

382372
# Construct unit vectors.
383373
template_unit_vector = torch.zeros(

alphafold3_pytorch/utils/model_utils.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,30 @@ def default_lambda_lr_fn(steps: int) -> float:
3737

3838

3939
@typecheck
40-
def distance_to_bins(
41-
distance: Float["... dist"], # type: ignore
42-
bins: Float[" bins"], # type: ignore
43-
) -> Int["... dist"]: # type: ignore
44-
"""Convert from distance to discrete bins, e.g., for `distance_labels`.
45-
46-
:param distance: The distance tensor.
47-
:param bins: The bins tensor.
48-
:return: The discrete bins.
49-
"""
50-
dist_from_dist_bins = einx.subtract(
51-
"... dist, dist_bins -> ... dist dist_bins", distance, bins
52-
).abs()
53-
return dist_from_dist_bins.argmin(dim=-1)
40+
def distance_to_dgram(
41+
distance: Float['... dist'],
42+
bins: Float[' bins'],
43+
return_labels: bool = False,
44+
) -> Int['... dist']:
45+
"""
46+
converting from distance to discrete bins, for distance_labels and pae_labels
47+
using the same logic as openfold
48+
"""
49+
50+
distance = distance ** 2
51+
52+
bins = F.pad(bins ** 2, (0, 1), value = float('inf'))
53+
low, high = bins[:-1], bins[1:]
54+
55+
one_hot = (
56+
einx.greater_equal('..., bin_low -> ... bin_low', distance, low) &
57+
einx.less('..., bin_high -> ... bin_high', distance, high)
58+
).long()
59+
60+
if return_labels:
61+
return one_hot.argmax(dim = -1)
62+
63+
return one_hot
5464

5565

5666
@typecheck

0 commit comments

Comments
 (0)