Skip to content

Commit 25305ee

Browse files
authored
Make training detection more robust for mmCIF bond insertion (#92)
* Update test_input.py
* Update inputs.py
1 parent a76ee7a commit 25305ee

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -539,7 +539,6 @@ def map_int_or_string_indices_to_mol(
539539
entries: dict,
540540
indices: Int[' _'] | List[str] | str,
541541
mol_keyname = 'rdchem_mol',
542-
chain = False,
543542
return_entries = False
544543
) -> List[Mol] | Tuple[List[Mol], List[dict]]:
545544

@@ -632,7 +631,7 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
632631

633632
for protein in proteins:
634633
mol_peptides, protein_entries = map_int_or_string_indices_to_mol(
635-
HUMAN_AMINO_ACIDS, protein, chain=True, return_entries=True
634+
HUMAN_AMINO_ACIDS, protein, return_entries=True
636635
)
637636
mol_proteins.append(mol_peptides)
638637

@@ -657,7 +656,7 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
657656

658657
for seq in ss_rnas:
659658
mol_seq, ss_rna_entries = map_int_or_string_indices_to_mol(
660-
RNA_NUCLEOTIDES, seq, chain=True, return_entries=True
659+
RNA_NUCLEOTIDES, seq, return_entries=True
661660
)
662661
mol_ss_rnas.append(mol_seq)
663662

@@ -675,7 +674,7 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
675674

676675
for seq in ss_dnas:
677676
mol_seq, ss_dna_entries = map_int_or_string_indices_to_mol(
678-
DNA_NUCLEOTIDES, seq, chain=True, return_entries=True
677+
DNA_NUCLEOTIDES, seq, return_entries=True
679678
)
680679
mol_ss_dnas.append(mol_seq)
681680

@@ -1005,9 +1004,26 @@ class PDBInput:
10051004
add_atom_ids: bool = False
10061005
add_atompair_ids: bool = False
10071006
directed_bonds: bool = False
1007+
training: bool = False
10081008
extract_atom_feats_fn: Callable[[Atom], Float["m dai"]] = default_extract_atom_feats_fn # type: ignore
10091009
extract_atompair_feats_fn: Callable[[Mol], Float["m m dapi"]] = default_extract_atompair_feats_fn # type: ignore
10101010

1011+
def __post_init__(self):
1012+
"""Run post-init checks."""
1013+
if not os.path.exists(self.mmcif_filepath):
1014+
raise FileNotFoundError(f"mmCIF file not found: {self.mmcif_filepath}.")
1015+
if not self.mmcif_filepath.endswith(".cif"):
1016+
raise ValueError(
1017+
f"mmCIF file `{self.mmcif_filepath}` must have a `.cif` file extension."
1018+
)
1019+
1020+
if self.msa_dir is not None and not os.path.exists(self.msa_dir):
1021+
raise FileNotFoundError(f"Provided MSA directory not found: {self.msa_dir}.")
1022+
if self.templates_dir is not None and not os.path.exists(self.templates_dir):
1023+
raise FileNotFoundError(
1024+
f"Provided templates directory not found: {self.templates_dir}."
1025+
)
1026+
10111027

10121028
@typecheck
10131029
def extract_chain_sequences_from_biomolecule_chemical_components(
@@ -1390,7 +1406,7 @@ def get_token_index_from_composite_atom_id(
13901406

13911407

13921408
@typecheck
1393-
def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> MoleculeInput:
1409+
def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput:
13941410
"""Convert a PDBInput to a MoleculeInput."""
13951411
i = pdb_input
13961412

@@ -1573,7 +1589,7 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> M
15731589
# per the AF3 supplement (Table 5, `token_bonds`)
15741590
bond_atom_indices = defaultdict(int)
15751591
for bond in biomol.bonds:
1576-
if not training:
1592+
if not i.training:
15771593
continue
15781594

15791595
# determine bond type

tests/test_input.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,7 @@ def test_pdbinput_input():
151151
if os.path.exists(filepath.replace(".cif", "-sampled.cif")):
152152
os.remove(filepath.replace(".cif", "-sampled.cif"))
153153

154-
train_pdb_input = PDBInput(filepath)
154+
train_pdb_input = PDBInput(filepath, training=True)
155155

156156
eval_pdb_input = PDBInput(filepath)
157157

0 commit comments

Comments
 (0)