Skip to content

Commit f88e91a

Browse files
committed
take care of reordering atom positions to the non-canonical order during training
1 parent 6258d30 commit f88e91a

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ class Alphafold3Input:
372372
templates: Float['t n n dt'] | None = None
373373
msa: Float['s n dm'] | None = None
374374
atom_pos: List[Float['_ 3']] | Float['m 3'] | None = None
375+
reorder_atom_pos: bool = True
375376
template_mask: Bool[' t'] | None = None
376377
msa_mask: Bool[' s'] | None = None
377378
distance_labels: Int['n n'] | None = None
@@ -770,6 +771,8 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
770771

771772
if i.add_output_atompos_indices:
772773
offset = 0
774+
775+
reorder_atompos_indices = []
773776
output_atompos_indices = []
774777

775778
for chain in chainable_biomol_entries:
@@ -784,6 +787,8 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
784787
num_atoms -= 1
785788
atom_reorder_indices = atom_reorder_indices[:-1]
786789

790+
reorder_atompos_indices.append(atom_reorder_indices)
791+
787792
reorder_back_indices = atom_reorder_indices.argsort()
788793
output_atompos_indices.append(reorder_back_indices + offset)
789794

@@ -792,6 +797,18 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
792797
output_atompos_indices = torch.cat(output_atompos_indices, dim = -1)
793798
output_atompos_indices = F.pad(output_atompos_indices, (0, total_atoms - output_atompos_indices.shape[-1]), value = -1)
794799

800+
# if atom positions are passed in, need to be reordered for the bonds between residues / nucleotides to be contiguous
801+
# todo - fix to have no reordering needed (bonds are built not contiguous, just hydroxyl removed)
802+
803+
if i.reorder_atom_pos:
804+
orig_order = torch.arange(total_atoms)
805+
reorder_atompos_indices = torch.cat(reorder_atompos_indices, dim = -1)
806+
reorder_atompos_indices = F.pad(reorder_atompos_indices, (0, total_atoms - reorder_atompos_indices.shape[-1]), value = -1)
807+
808+
reorder_indices = torch.where(reorder_atompos_indices != -1, reorder_atompos_indices, orig_order)
809+
810+
atom_pos = atom_pos[reorder_indices]
811+
795812
# create molecule input
796813

797814
molecule_input = MoleculeInput(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.101"
3+
version = "0.1.102"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)