|
75 | 75 | package_available, |
76 | 76 | ) |
77 | 77 |
|
| 78 | +from alphafold3_pytorch.utils.model_utils import distance_to_dgram |
| 79 | + |
78 | 80 | from frame_averaging_pytorch import FrameAverage |
79 | 81 |
|
80 | 82 | from taylor_series_linear_attention import TaylorSeriesLinearAttn |
@@ -455,28 +457,6 @@ def batch_repeat_interleave_pairwise( |
455 | 457 | pairwise = batch_repeat_interleave(pairwise, molecule_atom_lens) |
456 | 458 | return unpack_one(pairwise) |
457 | 459 |
|
458 | | -@typecheck |
459 | | -def distance_to_bins( |
460 | | - distance: Float['... dist'], |
461 | | - bins: Float[' bins'] |
462 | | -) -> Int['... dist']: |
463 | | - """ |
464 | | - converting from distance to discrete bins, for distance_labels and pae_labels |
465 | | - using the same logic as openfold |
466 | | - """ |
467 | | - |
468 | | - distance = distance ** 2 |
469 | | - |
470 | | - bins = F.pad(bins ** 2, (0, 1), value = float('inf')) |
471 | | - low, high = bins[:-1], bins[1:] |
472 | | - |
473 | | - one_hot = ( |
474 | | - einx.greater_equal('..., bin_low -> ... bin_low', distance, low) & |
475 | | - einx.less('..., bin_high -> ... bin_high', distance, high) |
476 | | - ).long() |
477 | | - |
478 | | - return one_hot.argmax(dim = -1) |
479 | | - |
480 | 460 | # linear and outer sum |
481 | 461 | # for single repr -> pairwise pattern throughout this architecture |
482 | 462 |
|
@@ -4501,7 +4481,9 @@ def forward( |
4501 | 4481 |
|
4502 | 4482 | intermolecule_dist = torch.cdist(pred_molecule_pos, pred_molecule_pos, p=2) |
4503 | 4483 |
|
4504 | | - dist_bin_indices = distance_to_bins(intermolecule_dist, self.atompair_dist_bins) |
| 4484 | + dist_bin_indices = distance_to_dgram( |
| 4485 | + intermolecule_dist, self.atompair_dist_bins, return_labels=True |
| 4486 | + ) |
4505 | 4487 | pairwise_repr = pairwise_repr + self.dist_bin_pairwise_embed(dist_bin_indices) |
4506 | 4488 |
|
4507 | 4489 | # pairformer stack |
@@ -6734,7 +6716,9 @@ def forward( |
6734 | 6716 | distogram_mask = atom_mask |
6735 | 6717 |
|
6736 | 6718 | distogram_dist = torch.cdist(distogram_pos, distogram_pos, p=2) |
6737 | | - distance_labels = distance_to_bins(distogram_dist, self.distance_bins) |
| 6719 | + distance_labels = distance_to_dgram( |
| 6720 | + distogram_dist, self.distance_bins, return_labels = True |
| 6721 | + ) |
6738 | 6722 |
|
6739 | 6723 | # account for representative distogram atom missing from residue (-1 set on distogram_atom_indices field) |
6740 | 6724 |
|
@@ -7014,9 +6998,9 @@ def forward( |
7014 | 6998 | mask=align_error_mask, |
7015 | 6999 | ) |
7016 | 7000 |
|
7017 | | - # calculate pae labels as alignment error binned to 64 (0 - 32A) (TODO: double-check correctness of `distance_to_bins`'s bin assignments) |
| 7001 | + # calculate pae labels as alignment error binned to 64 (0 - 32A) |
7018 | 7002 |
|
7019 | | - pae_labels = distance_to_bins(align_error, self.pae_bins) |
| 7003 | + pae_labels = distance_to_dgram(align_error, self.pae_bins, return_labels = True) |
7020 | 7004 |
|
7021 | 7005 | # set ignore index for invalid molecules or frames |
7022 | 7006 |
|
@@ -7058,7 +7042,7 @@ def forward( |
7058 | 7042 | # calculate pde labels as distance error binned to 64 (0 - 32A) |
7059 | 7043 |
|
7060 | 7044 | pde_dist = torch.abs(pde_pred_dist - pde_gt_dist) |
7061 | | - pde_labels = distance_to_bins(pde_dist, self.pde_bins) |
| 7045 | + pde_labels = distance_to_dgram(pde_dist, self.pde_bins, return_labels = True) |
7062 | 7046 |
|
7063 | 7047 | # account for representative molecule atom missing from residue (-1 set on molecule_atom_indices field) |
7064 | 7048 |
|
|
0 commit comments