Skip to content

Commit 8354dc6

Browse files
authored
Fix OXT parsing bug for amino acids, and add new subsetting options (#155)
* Update inputs.py
* Update inputs.py
* Update amino_acid_constants.py
1 parent 1091c0a commit 8354dc6

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

alphafold3_pytorch/common/amino_acid_constants.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,8 @@
4444
"CZ2",
4545
"CZ3",
4646
"NZ",
47-
"OXT",
47+
# "OXT", # NOTE: This often appears in mmCIF files, but it will not be used for any amino acid type in AlphaFold.
48+
"_",
4849
"_",
4950
"_",
5051
"_",

alphafold3_pytorch/inputs.py

Lines changed: 27 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2566,6 +2566,13 @@ def pdb_input_to_molecule_input(
25662566
assert len(molecules) == len(missing_atom_indices)
25672567
assert len(missing_token_indices) == num_tokens
25682568

2569+
mol_total_atoms = sum([mol.GetNumAtoms() for mol in molecules])
2570+
num_missing_atom_indices = sum(
2571+
len(mol_miss_atom_indices) for mol_miss_atom_indices in missing_atom_indices
2572+
)
2573+
num_present_atoms = mol_total_atoms - num_missing_atom_indices
2574+
assert num_present_atoms == int(biomol.atom_mask.sum())
2575+
25692576
# TODO: install additional token features once MSAs are available
25702577
# 0: f_profile
25712578
# 1: f_deletion_mean
@@ -2645,6 +2652,7 @@ def __init__(
26452652
spatial_interface_weight: float = 0.4,
26462653
crop_size: int = 384,
26472654
training: bool | None = None, # extra training flag placed by Alex on PDBInput
2655+
sample_only_pdb_ids: Set[str] | None = None,
26482656
**pdb_input_kwargs,
26492657
):
26502658
if isinstance(folder, str):
@@ -2660,6 +2668,7 @@ def __init__(
26602668
self.sampler = sampler
26612669
self.sample_type = sample_type
26622670
self.training = training
2671+
self.sample_only_pdb_ids = sample_only_pdb_ids
26632672
self.pdb_input_kwargs = pdb_input_kwargs
26642673

26652674
self.cropping_config = {
@@ -2679,6 +2688,12 @@ def __init__(
26792688
if file in sampler_pdb_ids
26802689
}
26812690

2691+
if exists(sample_only_pdb_ids):
2692+
assert exists(self.sampler), "A sampler must be provided to use `sample_only_pdb_ids`."
2693+
assert all(
2694+
pdb_id in sampler_pdb_ids for pdb_id in sample_only_pdb_ids
2695+
), "Some PDB IDs in `sample_only_pdb_ids` are not present in the dataset's sampler mappings."
2696+
26822697
assert len(self) > 0, f"No valid mmCIFs / PDBs found at {str(folder)}"
26832698

26842699
def __len__(self):
@@ -2690,10 +2705,18 @@ def __getitem__(self, idx: int | str) -> PDBInput:
26902705
sampled_id = None
26912706

26922707
if exists(self.sampler):
2693-
if self.sample_type == "clustered":
2694-
(sampled_id,) = self.sampler.cluster_based_sample(1)
2695-
else:
2696-
(sampled_id,) = self.sampler.sample(1)
2708+
sample_fn = (
2709+
self.sampler.cluster_based_sample
2710+
if self.sample_type == "clustered"
2711+
else self.sampler.sample
2712+
)
2713+
(sampled_id,) = sample_fn(1)
2714+
2715+
# ensure that the sampled PDB ID is in the specified set of PDB IDs from which to sample
2716+
2717+
if exists(self.sample_only_pdb_ids):
2718+
while sampled_id[0] not in self.sample_only_pdb_ids:
2719+
(sampled_id,) = sample_fn(1)
26972720

26982721
pdb_id, chain_id_1, chain_id_2 = None, None, None
26992722

0 commit comments

Comments
 (0)