Skip to content

Commit 03de1c2

Browse files
authored
Handle edge case for unknown polymer residues (#228)
* Update cli.py
* Update amino_acid_constants.py
* Update dna_constants.py
* Update rna_constants.py
* Update inputs.py
1 parent a62595e commit 03de1c2

File tree

5 files changed

+14
-7
lines changed

5 files changed

+14
-7
lines changed

alphafold3_pytorch/cli.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,15 +16,15 @@
1616
@click.command()
1717
@click.option('-ckpt', '--checkpoint', type = str, help = 'path to alphafold3 checkpoint')
1818
@click.option('-p', '--protein', type = str, help = 'one protein sequence')
19-
@click.option('-o', '--output', type = str, help = 'output path', default = 'output.mmcif')
19+
@click.option('-o', '--output', type = str, help = 'output path', default = 'output.cif')
2020
def cli(
2121
checkpoint: str,
2222
protein: str,
2323
output: str
2424
):
2525

2626
checkpoint_path = Path(checkpoint)
27-
assert checkpoint_path.exists(), f'alphafold3 checkpoint must exist at {str(checkpoint_path)}'
27+
assert checkpoint_path.exists(), f'AlphaFold 3 checkpoint must exist at {str(checkpoint_path)}'
2828

2929
alphafold3_input = Alphafold3Input(
3030
proteins = [protein],
@@ -44,4 +44,4 @@ def cli(
4444
pdb_writer.set_structure(structure)
4545
pdb_writer.save(str(output_path))
4646

47-
print(f'mmcif saved to {str(output_path)}')
47+
print(f'mmCIF file saved to {str(output_path)}')

alphafold3_pytorch/common/amino_acid_constants.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -216,7 +216,7 @@
216216

217217
def _make_constants():
218218
"""Fill the array(s) above."""
219-
for restype, restype_letter in enumerate(restypes):
219+
for restype, restype_letter in enumerate(restype_1to3.keys()):
220220
resname = restype_1to3[restype_letter]
221221
for compact_atomidx, atomname in enumerate(restype_name_to_compact_atom_names[resname]):
222222
if not atomname:

alphafold3_pytorch/common/dna_constants.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -255,7 +255,7 @@
255255

256256
def _make_constants():
257257
"""Fill the array(s) above."""
258-
for restype, restype_letter in enumerate(restypes):
258+
for restype, restype_letter in enumerate(restype_1to3.keys()):
259259
resname = restype_1to3[restype_letter]
260260
for atomname in restype_name_to_compact_atom_names[resname]:
261261
if not atomname:

alphafold3_pytorch/common/rna_constants.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -246,7 +246,7 @@
246246

247247
def _make_constants():
248248
"""Fill the array(s) above."""
249-
for restype, restype_letter in enumerate(restypes):
249+
for restype, restype_letter in enumerate(restype_1to3.keys()):
250250
resname = restype_1to3[restype_letter]
251251
for atomname in restype_name_to_compact_atom_names[resname]:
252252
if not atomname:

alphafold3_pytorch/inputs.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2281,10 +2281,17 @@ def extract_canonical_molecules_from_biomolecule_chains(
22812281
contiguous_res_atom_mapping = np.vectorize(contiguous_res_atom_mapping.get)(
22822282
res_atom_mapping
22832283
)
2284-
22852284
res_atom_positions = atom_positions[res_index][res_atom_mask][
22862285
contiguous_res_atom_mapping
22872286
]
2287+
2288+
num_atom_positions = len(res_atom_positions) + len(missing_atom_indices)
2289+
if num_atom_positions != mol.GetNumAtoms():
2290+
raise ValueError(
2291+
f"The number of (missing and present) atom positions ({num_atom_positions}) for residue {res} does not match the number of atoms in the RDKit molecule ({mol.GetNumAtoms()}). "
2292+
"Please ensure that these input features are correctly paired. Skipping this example."
2293+
)
2294+
22882295
mol = add_atom_positions_to_mol(
22892296
mol,
22902297
res_atom_positions.reshape(-1, 3),

0 commit comments

Comments
 (0)