Skip to content

Commit 9bbccd2

Browse files
committed
Take care of reordering the atom positions back to canonical order when the `add_output_atompos_indices` flag is set on Alphafold3Input
1 parent fbb06b3 commit 9bbccd2

File tree

3 files changed

+44
-6
lines changed

3 files changed

+44
-6
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,8 @@ def alphafold3_input_to_molecule_input(
426426

427427
i = alphafold3_input
428428

429+
chainable_biomol_entries: List[List[dict]] = [] # for reordering the atom positions at the end
430+
429431
ss_rnas = list(i.ss_rna)
430432
ss_dnas = list(i.ss_dna)
431433

@@ -467,25 +469,31 @@ def alphafold3_input_to_molecule_input(
467469
protein_ids = maybe_string_to_int(HUMAN_AMINO_ACIDS, protein) + protein_offset
468470
molecule_ids.append(protein_ids)
469471

472+
chainable_biomol_entries.append(protein_entries)
473+
470474
# convert all single stranded nucleic acids to mol
471475

472476
mol_ss_dnas = []
473477
mol_ss_rnas = []
474478

475479
for seq in ss_rnas:
476-
mol_seq = map_int_or_string_indices_to_mol(RNA_NUCLEOTIDES, seq, chain = True)
480+
mol_seq, ss_rna_entries = map_int_or_string_indices_to_mol(RNA_NUCLEOTIDES, seq, chain = True, return_entries = True)
477481
mol_ss_rnas.append(mol_seq)
478482

479483
rna_ids = maybe_string_to_int(RNA_NUCLEOTIDES, seq) + rna_offset
480484
molecule_ids.append(rna_ids)
481485

486+
chainable_biomol_entries.append(ss_rna_entries)
487+
482488
for seq in ss_dnas:
483-
mol_seq = map_int_or_string_indices_to_mol(DNA_NUCLEOTIDES, seq, chain = True)
489+
mol_seq, ss_dna_entries = map_int_or_string_indices_to_mol(DNA_NUCLEOTIDES, seq, chain = True, return_entries = True)
484490
mol_ss_dnas.append(mol_seq)
485491

486492
dna_ids = maybe_string_to_int(DNA_NUCLEOTIDES, seq) + dna_offset
487493
molecule_ids.append(dna_ids)
488494

495+
chainable_biomol_entries.append(ss_dna_entries)
496+
489497
# convert metal ions to rdchem.Mol
490498

491499
metal_ions = alphafold3_input.metal_ions
@@ -558,6 +566,8 @@ def alphafold3_input_to_molecule_input(
558566
*metal_ions_pool_lens
559567
]
560568

569+
total_atoms = sum(token_pool_lens)
570+
561571
# construct the token bonds
562572

563573
# will be linearly connected for proteins and nucleic acids
@@ -724,11 +734,38 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
724734
# handle atom positions
725735

726736
atom_pos = i.atom_pos
737+
output_atompos_indices = None
727738

728739
if exists(atom_pos):
729740
if isinstance(atom_pos, list):
730741
atom_pos = torch.cat(atom_pos, dim = -2)
731742

743+
# build the indices used to automatically reorder the atom positions back to canonical order
744+
745+
if i.add_output_atompos_indices:
746+
offset = 0
747+
output_atompos_indices = []
748+
749+
for chain in chainable_biomol_entries:
750+
for idx, entry in enumerate(chain):
751+
is_last = idx == (len(chain) - 1)
752+
753+
mol = entry['rdchem_mol']
754+
num_atoms = mol.GetNumAtoms()
755+
atom_reorder_indices = entry['atom_reorder_indices']
756+
757+
if not is_last:
758+
num_atoms -= 1
759+
atom_reorder_indices = atom_reorder_indices[:-1]
760+
761+
reorder_back_indices = atom_reorder_indices.argsort()
762+
output_atompos_indices.append(reorder_back_indices + offset)
763+
764+
offset += num_atoms
765+
766+
output_atompos_indices = torch.cat(output_atompos_indices, dim = -1)
767+
output_atompos_indices = F.pad(output_atompos_indices, (0, total_atoms - output_atompos_indices.shape[-1]), value = -1)
768+
732769
# create molecule input
733770

734771
molecule_input = MoleculeInput(
@@ -741,6 +778,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
741778
additional_token_feats = default(i.additional_token_feats, torch.zeros(num_tokens, 2)),
742779
is_molecule_types = is_molecule_types,
743780
atom_pos = atom_pos,
781+
output_atompos_indices = output_atompos_indices,
744782
templates = i.templates,
745783
msa = i.msa,
746784
template_mask = i.template_mask,

alphafold3_pytorch/life.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,11 +362,11 @@ def remove_atom_from_mol(mol: Mol, atom_idx: int) -> Mol:
362362
atom_order[entry['last_atom_idx']] = 1e4
363363
atom_order[entry['hydroxyl_idx']] = 1e4 + 1
364364

365-
atom_reorder = atom_order.argsort().tolist()
365+
atom_reorder = atom_order.argsort()
366366

367-
mol = Chem.RenumberAtoms(mol, atom_reorder)
367+
mol = Chem.RenumberAtoms(mol, atom_reorder.tolist())
368368

369369
entry.update(
370-
atom_reorder = atom_reorder,
370+
atom_reorder_indices = atom_reorder,
371371
rdchem_mol = mol
372372
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.92"
3+
version = "0.1.94"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)