Skip to content

Commit c4ba571

Browse files
authored
add a hard validation for the indices under an env flag (#174)
add a hard validation for all indices under an env flag or `hard_validate` on Alphafold3.forward
1 parent 0ddb3b0 commit c4ba571

File tree

5 files changed

+77
-10
lines changed

5 files changed

+77
-10
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
DEFAULT_NUM_MOLECULE_MODS,
6060
ADDITIONAL_MOLECULE_FEATS,
6161
BatchedAtomInput,
62+
hard_validate_atom_indices_ascending
6263
)
6364

6465
from alphafold3_pytorch.common.biomolecule import (
@@ -5347,7 +5348,8 @@ def forward(
53475348
rollout_show_tqdm_pbar: bool = False,
53485349
detach_when_recycling: bool = None,
53495350
min_conf_resolution: float = 0.1,
5350-
max_conf_resolution: float = 4.0
5351+
max_conf_resolution: float = 4.0,
5352+
hard_validate: bool = False
53515353
) -> (
53525354
Float['b m 3'] |
53535355
Tuple[Float['b m 3'], ConfidenceHeadLogits] |
@@ -5363,6 +5365,13 @@ def forward(
53635365
assert atom_inputs.shape[-1] == self.dim_atom_inputs, f'expected {self.dim_atom_inputs} for atom_inputs feature dimension, but received {atom_inputs.shape[-1]}'
53645366
assert atompair_inputs.shape[-1] == self.dim_atompair_inputs, f'expected {self.dim_atompair_inputs} for atompair_inputs feature dimension, but received {atompair_inputs.shape[-1]}'
53655367

5368+
# hard validate when `hard_validate` is passed in or the debug env variable is turned on
5369+
5370+
if hard_validate or IS_DEBUGGING:
5371+
maybe(hard_validate_atom_indices_ascending)(distogram_atom_indices, 'distogram_atom_indices')
5372+
maybe(hard_validate_atom_indices_ascending)(molecule_atom_indices, 'molecule_atom_indices')
5373+
maybe(hard_validate_atom_indices_ascending)(atom_indices_for_frame, 'atom_indices_for_frame')
5374+
53665375
# soft validate
53675376

53685377
valid_molecule_atom_mask = valid_atom_len_mask = molecule_atom_lens >= 0
@@ -5385,8 +5394,6 @@ def forward(
53855394

53865395
assert exists(molecule_atom_lens) or exists(atom_mask)
53875396

5388-
# hard validate when debug env variable is turned on
5389-
53905397
if IS_DEBUGGING:
53915398
assert (molecule_atom_lens >= 0).all(), 'molecule_atom_lens must be greater or equal to 0'
53925399

alphafold3_pytorch/inputs.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,33 @@ def inner(x, *args, **kwargs):
166166
return fn(x, *args, **kwargs)
167167
return inner
168168

169+
# validation functions
170+
171+
def hard_validate_atom_indices_ascending(
    indices: Int['b n'] | Int['b n 3'],
    error_msg_field: str = 'indices'
):
    """Hard-validate that present atom indices never descend along the sequence.

    Intended for 'distogram_atom_indices', 'molecule_atom_indices', and
    'atom_indices_for_frame'.

    Each batch sample is checked independently. An entry (a single index, or a
    group of indices along the last dimension) counts as present only when all
    of its values are non-negative; other entries are dropped before checking.
    Every index of a present entry must then be greater than or equal to every
    index of the preceding present entry.

    NOTE(review): equal indices in neighbouring entries pass, because the check
    is `>= 0`. The original comment described identical indices as invalid -
    confirm whether a strict `> 0` comparison was intended.

    Args:
        indices: integer tensor annotated as 'b n' or 'b n 3'.
        error_msg_field: name of the field being validated, used only in the
            error message.

    Raises:
        AssertionError: if any index is smaller than an index in the preceding
            present entry. Raised explicitly rather than through `assert`, so
            the check still runs under `python -O`.
    """

    # give single indices a trailing dimension of 1 so both accepted ranks share one code path

    if indices.ndim == 2:
        indices = rearrange(indices, '... -> ... 1')

    for batch_index, sample_indices in enumerate(indices):

        # keep only the entries where every index is non-negative

        all_present = (sample_indices >= 0).all(dim = -1)
        present_indices = sample_indices[all_present]

        # relaxed assumption that if nothing is present, or only one entry is, it passes the test

        if present_indices.shape[0] <= 1:
            continue

        # all pairwise differences between each entry's indices and the previous entry's indices

        difference = einx.subtract('n i, n j -> n (i j)', present_indices[1:], present_indices[:-1])

        if not (difference >= 0).all():
            raise AssertionError(f'detected invalid {error_msg_field} for batch index {batch_index}: {present_indices}')
195+
169196
# atom level, what Alphafold3 accepts
170197

171198
UNCOLLATABLE_ATOM_INPUT_FIELDS = {'filepath'}

alphafold3_pytorch/mocks.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
DEFAULT_NUM_MOLECULE_MODS
1010
)
1111

12+
from alphafold3_pytorch.utils.model_utils import exclusive_cumsum
13+
1214
# mock dataset
1315

1416
class MockAtomDataset(Dataset):
@@ -37,6 +39,8 @@ def __getitem__(self, idx):
3739
atompair_inputs = torch.randn(atom_seq_len, atom_seq_len, 5)
3840

3941
molecule_atom_lens = torch.randint(1, self.atoms_per_window, (seq_len,))
42+
atom_offsets = exclusive_cumsum(molecule_atom_lens)
43+
4044
additional_molecule_feats = torch.randint(0, 2, (seq_len, 5))
4145
additional_token_feats = torch.randn(seq_len, 2)
4246
is_molecule_types = torch.randint(0, 2, (seq_len, IS_MOLECULE_TYPES)).bool()
@@ -77,6 +81,9 @@ def __getitem__(self, idx):
7781
molecule_atom_indices = molecule_atom_lens - 1
7882
distogram_atom_indices = molecule_atom_lens - 1
7983

84+
molecule_atom_indices += atom_offsets
85+
distogram_atom_indices += atom_offsets
86+
8087
distance_labels = torch.randint(0, 37, (atom_seq_len, atom_seq_len))
8188
resolved_labels = torch.randint(0, 2, (atom_seq_len,))
8289

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.3.15"
3+
version = "0.3.16"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
os.environ['TYPECHECK'] = 'True'
3+
os.environ['DEBUG'] = 'True'
34

45
import pytest
56
import random
@@ -63,6 +64,8 @@
6364
default_extract_atompair_feats_fn
6465
)
6566

67+
from alphafold3_pytorch.utils.model_utils import exclusive_cumsum
68+
6669
DATA_TEST_PDB_ID = '721p'
6770

6871
def test_atom_ref_pos_to_atompair_inputs():
@@ -538,11 +541,13 @@ def test_alphafold3(
538541
num_molecule_mods: int,
539542
):
540543
seq_len = 16
541-
atom_seq_len = 32
542544
atoms_per_window = 27
543545

544546
molecule_atom_indices = torch.randint(0, 2, (2, seq_len)).long()
545-
molecule_atom_lens = torch.full((2, seq_len), 2).long()
547+
molecule_atom_lens = torch.full((2, seq_len), 3).long()
548+
549+
atom_seq_len = molecule_atom_lens.sum(dim = -1).amax()
550+
atom_offset = exclusive_cumsum(molecule_atom_lens)
546551

547552
token_bonds = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
548553

@@ -564,7 +569,8 @@ def test_alphafold3(
564569

565570
atom_indices_for_frame = None
566571
if calculate_pae:
567-
atom_indices_for_frame = repeat(torch.arange(3), 'c -> b n c', b = 2, n = seq_len)
572+
atom_indices_for_frame = repeat(torch.arange(3), 'c -> b n c', b = 2, n = seq_len).clone()
573+
atom_indices_for_frame += atom_offset[..., None]
568574

569575
missing_atom_mask = None
570576
if missing_atoms:
@@ -586,6 +592,13 @@ def test_alphafold3(
586592

587593
resolved_labels = torch.randint(0, 2, (2, atom_seq_len))
588594

595+
# offset indices correctly
596+
597+
distogram_atom_indices += atom_offset
598+
molecule_atom_indices += atom_offset
599+
600+
# alphafold3
601+
589602
alphafold3 = Alphafold3(
590603
dim_atom_inputs = 77,
591604
dim_pairwise = 8,
@@ -750,8 +763,11 @@ def test_alphafold3_force_return_loss():
750763
seq_len = 16
751764
atom_seq_len = 32
752765

753-
molecule_atom_indices = torch.randint(0, 2, (2, seq_len)).long()
754766
molecule_atom_lens = torch.full((2, seq_len), 2).long()
767+
atom_offsets = exclusive_cumsum(molecule_atom_lens)
768+
769+
molecule_atom_indices = torch.randint(0, 2, (2, seq_len)).long()
770+
molecule_atom_indices += atom_offsets
755771

756772
atom_inputs = torch.randn(2, atom_seq_len, 77)
757773
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
@@ -761,7 +777,9 @@ def test_alphafold3_force_return_loss():
761777
molecule_ids = torch.randint(0, 32, (2, seq_len))
762778

763779
atom_pos = torch.randn(2, atom_seq_len, 3)
780+
764781
distogram_atom_indices = molecule_atom_lens - 1
782+
distogram_atom_indices += atom_offsets
765783

766784
resolved_labels = torch.randint(0, 2, (2, atom_seq_len))
767785

@@ -828,8 +846,11 @@ def test_alphafold3_force_return_loss_with_confidence_logits():
828846
seq_len = 16
829847
atom_seq_len = 32
830848

831-
molecule_atom_indices = torch.randint(0, 2, (2, seq_len)).long()
832849
molecule_atom_lens = torch.full((2, seq_len), 2).long()
850+
atom_offsets = exclusive_cumsum(molecule_atom_lens)
851+
852+
molecule_atom_indices = torch.randint(0, 2, (2, seq_len)).long()
853+
molecule_atom_indices += atom_offsets
833854

834855
atom_inputs = torch.randn(2, atom_seq_len, 77)
835856
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
@@ -840,6 +861,7 @@ def test_alphafold3_force_return_loss_with_confidence_logits():
840861

841862
atom_pos = torch.randn(2, atom_seq_len, 3)
842863
distogram_atom_indices = molecule_atom_lens - 1
864+
distogram_atom_indices += atom_offsets
843865

844866
resolved_labels = torch.randint(0, 2, (2, atom_seq_len))
845867

@@ -916,8 +938,11 @@ def test_alphafold3_with_atom_and_bond_embeddings():
916938
seq_len = 16
917939
atom_seq_len = 32
918940

919-
molecule_atom_indices = torch.randint(0, 2, (2, seq_len)).long()
920941
molecule_atom_lens = torch.full((2, seq_len), 2).long()
942+
atom_offset = exclusive_cumsum(molecule_atom_lens)
943+
944+
molecule_atom_indices = torch.randint(0, 2, (2, seq_len)).long()
945+
molecule_atom_indices += atom_offset
921946

922947
atom_ids = torch.randint(0, 7, (2, atom_seq_len))
923948
atompair_ids = torch.randint(0, 3, (2, atom_seq_len, atom_seq_len))
@@ -940,6 +965,7 @@ def test_alphafold3_with_atom_and_bond_embeddings():
940965

941966
atom_pos = torch.randn(2, atom_seq_len, 3)
942967
distogram_atom_indices = molecule_atom_lens - 1 # last atom, as an example
968+
distogram_atom_indices += atom_offset
943969

944970
resolved_labels = torch.randint(0, 2, (2, atom_seq_len))
945971

0 commit comments

Comments
 (0)