Skip to content

Commit 52d1475

Browse files
authored
Standardize all distogram binning logic (#213)
* Update README.md
* Update template_parsing.py
* Update alphafold3.py
* Update model_utils.py
* Update trainer_with_atom_dataset_created_from_pdb.yaml
* Update alphafold3.yaml
* Update trainer.yaml
* Update trainer_with_atom_dataset.yaml
* Update trainer_with_atom_dataset.yaml
* Update alphafold3.yaml
* Update trainer.yaml
* Update trainer_with_pdb_dataset.yaml
* Update trainer_with_pdb_dataset_and_weighted_sampling.yaml
* Update training_with_pdb_dataset.yaml
* Update test_dataloading.py
* Update test_input.py
* Update test_af3.py
* Update test_trainer.py
* Update test_af3.py
* Update mocks.py
* Update alphafold3.py
* Update alphafold3.py
1 parent c43f8c0 commit 52d1475

13 files changed

44 additions and 44 deletions

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ alphafold3 = Alphafold3(
202202
dim_atompair_inputs = 5,
203203
atoms_per_window = 27,
204204
dim_template_feats = 108,
205-
num_dist_bins = 38,
205+
num_dist_bins = 64,
206206
num_molecule_mods = 0,
207207
confidence_head_kwargs = dict(
208208
pairformer_depth = 1

alphafold3_pytorch/alphafold3.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4218,7 +4218,7 @@ def __init__(
42184218
self,
42194219
*,
42204220
dim_pairwise = 128,
4221-
num_dist_bins = 38,
4221+
num_dist_bins = 64,
42224222
dim_atom = 128,
42234223
atom_resolution = False,
42244224
checkpoint = False,
@@ -5236,11 +5236,7 @@ class ComputeModelSelectionScore(Module):
52365236
def __init__(
52375237
self,
52385238
eps: float = 1e-8,
5239-
dist_breaks: Float[" dist_break"] = torch.linspace(
5240-
2.3125,
5241-
21.6875,
5242-
37,
5243-
),
5239+
dist_breaks: Float[" dist_break"] = torch.linspace(2, 22, 63),
52445240
nucleic_acid_cutoff: float = 30.0,
52455241
other_cutoff: float = 15.0,
52465242
contact_mask_threshold: float = 8.0,
@@ -5891,7 +5887,7 @@ def __init__(
58915887
num_atom_embeds: int | None = None,
58925888
num_atompair_embeds: int | None = None,
58935889
num_molecule_mods: int | None = DEFAULT_NUM_MOLECULE_MODS,
5894-
distance_bins: List[float] = torch.linspace(3, 20, 38).float().tolist(),
5890+
distance_bins: List[float] = torch.linspace(2, 22, 64).float().tolist(), # NOTE: in paper, they reuse AF2's setup of having 64 bins from 2 to 22
58955891
pae_bins: List[float] = torch.linspace(0.5, 32, 64).float().tolist(),
58965892
pde_bins: List[float] = torch.linspace(0.5, 32, 64).float().tolist(),
58975893
ignore_index = -1,

alphafold3_pytorch/data/template_parsing.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,16 +121,16 @@ def parse_m8(
121121
template_release_date = extract_mmcif_metadata_field(
122122
template_mmcif_object, "release_date"
123123
)
124-
template_biomol = _from_mmcif_object(
125-
template_mmcif_object, chain_ids=set(template_chain)
126-
)
127124
if not (
128125
exists(template_cutoff_date)
129126
and datetime.strptime(template_release_date, "%Y-%m-%d") <= template_cutoff_date
130127
):
131128
continue
132129
elif not exists(template_cutoff_date):
133130
pass
131+
template_biomol = _from_mmcif_object(
132+
template_mmcif_object, chain_ids=set(template_chain)
133+
)
134134
if len(template_biomol.atom_positions):
135135
template_biomols.append((template_biomol, template_type))
136136
except Exception as e:
@@ -148,7 +148,7 @@ def _extract_template_features(
148148
query_chemtype: List[str],
149149
num_restype_classes: int = 32,
150150
num_distogram_bins: int = 39,
151-
distance_bins: List[float] = torch.linspace(3.25, 50.75, 38).float(),
151+
distance_bins: List[float] = torch.linspace(3.25, 50.75, 39).float(),
152152
verbose: bool = False,
153153
) -> Dict[str, Any]:
154154
"""Parse atom positions in the target structure and align with the query.
@@ -181,9 +181,9 @@ def _extract_template_features(
181181
f"Mapping length {len(mapping)} must match query sequence length {len(query_sequence)} "
182182
f"and query chemtype length {len(query_chemtype)}."
183183
)
184-
assert num_distogram_bins == len(distance_bins) + 1, (
184+
assert num_distogram_bins == len(distance_bins), (
185185
f"Number of distance bins {num_distogram_bins} must match the length of distance bins "
186-
f"{len(distance_bins)} plus one."
186+
f"{len(distance_bins)}."
187187
)
188188

189189
all_atom_positions = template_biomol.atom_positions

alphafold3_pytorch/mocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __getitem__(self, idx):
8686
molecule_atom_indices += atom_offsets
8787
distogram_atom_indices += atom_offsets
8888

89-
distance_labels = torch.randint(0, 37, (seq_len, seq_len))
89+
distance_labels = torch.randint(0, 64, (seq_len, seq_len))
9090
resolved_labels = torch.randint(0, 2, (atom_seq_len,))
9191

9292
majority_asym_id = asym_id.mode().values.item()

alphafold3_pytorch/utils/model_utils.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,27 +38,31 @@ def default_lambda_lr_fn(steps: int) -> float:
3838

3939
@typecheck
4040
def distance_to_dgram(
41-
distance: Float['... dist'],
42-
bins: Float[' bins'],
41+
distance: Float["... dist"], # type: ignore
42+
bins: Float[" bins"], # type: ignore
4343
return_labels: bool = False,
44-
) -> Int['... dist'] | Int['... dist bins']:
45-
"""
46-
converting from distance to discrete bins, for distance_labels and pae_labels
47-
using the same logic as openfold
44+
) -> Int["... dist"] | Int["... dist bins"]: # type: ignore
45+
"""Converting from distance to discrete bins, e.g., for distance_labels and pae_labels using
46+
the same logic as OpenFold.
47+
48+
:param distance: The distance tensor.
49+
:param bins: The bins tensor.
50+
:param return_labels: Whether to return the labels.
51+
:return: The one-hot bins tensor or the bin labels.
4852
"""
4953

50-
distance = distance ** 2
54+
distance = distance.abs()
5155

52-
bins = F.pad(bins ** 2, (0, 1), value = float('inf'))
56+
bins = F.pad(bins, (0, 1), value = float('inf'))
5357
low, high = bins[:-1], bins[1:]
5458

5559
one_hot = (
56-
einx.greater_equal('..., bin_low -> ... bin_low', distance, low) &
57-
einx.less('..., bin_high -> ... bin_high', distance, high)
60+
einx.greater_equal("..., bin_low -> ... bin_low", distance, low)
61+
& einx.less("..., bin_high -> ... bin_high", distance, high)
5862
).long()
5963

6064
if return_labels:
61-
return one_hot.argmax(dim = -1)
65+
return one_hot.argmax(dim=-1)
6266

6367
return one_hot
6468

tests/configs/trainer_with_atom_dataset_created_from_pdb.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ model:
1111
dim_template_model: 8
1212
atoms_per_window: 27
1313
dim_template_feats: 108
14-
num_dist_bins: 38
14+
num_dist_bins: 64
1515
ignore_index: -1
1616
num_dist_bins: null
1717
num_plddt_bins: 50

tests/configs/trainer_with_pdb_dataset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ model:
1111
dim_template_model: 8
1212
atoms_per_window: 27
1313
dim_template_feats: 108
14-
num_dist_bins: 38
14+
num_dist_bins: 64
1515
ignore_index: -1
1616
num_dist_bins: null
1717
num_plddt_bins: 50

tests/configs/trainer_with_pdb_dataset_and_weighted_sampling.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ model:
1111
dim_template_model: 8
1212
atoms_per_window: 27
1313
dim_template_feats: 108
14-
num_dist_bins: 38
14+
num_dist_bins: 64
1515
ignore_index: -1
1616
num_dist_bins: null
1717
num_plddt_bins: 50

tests/configs/training_with_pdb_dataset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ model:
1717
dim_template_model: 8
1818
atoms_per_window: 27
1919
dim_template_feats: 108
20-
num_dist_bins: 38
20+
num_dist_bins: 64
2121
ignore_index: -1
2222
num_dist_bins: null
2323
num_plddt_bins: 50

tests/test_af3.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ def test_alphafold3(
698698
dim_token = 8,
699699
atoms_per_window = atoms_per_window,
700700
dim_template_feats = 108,
701-
num_dist_bins = 38,
701+
num_dist_bins = 64,
702702
num_molecule_mods = num_molecule_mods,
703703
confidence_head_kwargs = dict(
704704
pairformer_depth = 1
@@ -804,7 +804,7 @@ def test_alphafold3_without_msa_and_templates():
804804
alphafold3 = Alphafold3(
805805
dim_atom_inputs = 77,
806806
dim_template_feats = 108,
807-
num_dist_bins = 38,
807+
num_dist_bins = 64,
808808
num_molecule_mods = 0,
809809
checkpoint_trunk_pairformer = True,
810810
checkpoint_diffusion_module = True,
@@ -871,7 +871,7 @@ def test_alphafold3_force_return_loss():
871871
alphafold3 = Alphafold3(
872872
dim_atom_inputs = 77,
873873
dim_template_feats = 108,
874-
num_dist_bins = 38,
874+
num_dist_bins = 64,
875875
num_molecule_mods = 0,
876876
confidence_head_kwargs = dict(
877877
pairformer_depth = 1
@@ -953,7 +953,7 @@ def test_alphafold3_force_return_loss_with_confidence_logits():
953953
alphafold3 = Alphafold3(
954954
dim_atom_inputs = 77,
955955
dim_template_feats = 108,
956-
num_dist_bins = 38,
956+
num_dist_bins = 64,
957957
num_molecule_mods = 0,
958958
confidence_head_kwargs = dict(
959959
pairformer_depth = 1
@@ -1170,7 +1170,7 @@ def test_model_selection_score():
11701170
atom_mask = torch.randint(0, 2, (atom_pos_true.shape[:-1])).type_as(atom_pos_true).bool()
11711171
tok_repr_atm_mask = torch.randint(0, 2, (batch_size, seq_len)).bool()
11721172

1173-
dist_logits = torch.randn(batch_size, 38, seq_len, seq_len)
1173+
dist_logits = torch.randn(batch_size, 64, seq_len, seq_len)
11741174
pde_logits = torch.randn(batch_size, 64, seq_len, seq_len)
11751175

11761176
chain_length = [random.randint(seq_len // 4, seq_len //2)
@@ -1222,7 +1222,7 @@ def test_model_selection_score_end_to_end():
12221222
dim_token = 8,
12231223
atoms_per_window = 27,
12241224
dim_template_feats = 108,
1225-
num_dist_bins = 38,
1225+
num_dist_bins = 64,
12261226
confidence_head_kwargs = dict(
12271227
pairformer_depth = 1
12281228
),
@@ -1427,7 +1427,7 @@ def test_readme2():
14271427
dim_atompair_inputs = 5,
14281428
atoms_per_window = 27,
14291429
dim_template_feats = 108,
1430-
num_dist_bins = 38,
1430+
num_dist_bins = 64,
14311431
num_molecule_mods = 0,
14321432
confidence_head_kwargs = dict(
14331433
pairformer_depth = 1

0 commit comments

Comments
 (0)