Skip to content

Commit 30c76e9

Browse files
committed
first pass at marking the phosphodiesterase and peptide bonds correctly without reordering the atoms
1 parent dcde445 commit 30c76e9

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,10 @@ def flatten(arr):
102102
def exclusive_cumsum(t):
103103
return t.cumsum(dim = -1) - t
104104

105-
def pad_to_len(t, length, value = 0):
106-
return F.pad(t, (0, max(0, length - t.shape[-1])), value = value)
105+
def pad_to_len(t, length, value = 0, dim = -1):
106+
assert dim < 0
107+
zeros = (0, 0) * (-dim - 1)
108+
return F.pad(t, (*zeros, 0, max(0, length - t.shape[-1])), value = value)
107109

108110
def compose(*fns: Callable):
109111
# for chaining from Alphafold3Input -> MoleculeInput -> AtomInput
@@ -205,6 +207,7 @@ class MoleculeInput:
205207
molecule_ids: Int[' n']
206208
additional_molecule_feats: Int[f'n {ADDITIONAL_MOLECULE_FEATS}']
207209
is_molecule_types: Bool[f'n {IS_MOLECULE_TYPES}']
210+
src_tgt_atom_indices: Int['n 2']
208211
token_bonds: Bool['n n']
209212
molecule_atom_indices: List[int | None] | None = None
210213
distogram_atom_indices: List[int | None] | None = None
@@ -308,7 +311,10 @@ def molecule_to_atom_input(
308311

309312
# for every molecule, build the bonds id matrix and add to `atompair_ids`
310313

311-
for idx, (mol, is_first_mol_in_chain, is_chainable_biomolecule, offset) in enumerate(zip(molecules, is_first_mol_in_chains, is_chainable_biomolecules, offsets)):
314+
prev_mol = None
315+
prev_src_tgt_atom_indices = None
316+
317+
for idx, (mol, is_first_mol_in_chain, is_chainable_biomolecule, src_tgt_atom_indices, offset) in enumerate(zip(molecules, is_first_mol_in_chains, is_chainable_biomolecules, i.src_tgt_atom_indices, offsets)):
312318

313319
coordinates = []
314320
updates = []
@@ -349,15 +355,21 @@ def molecule_to_atom_input(
349355
atompair_ids[row_col_slice, row_col_slice] = mol_atompair_ids
350356

351357
# if is chainable biomolecule
352-
# and not the first biomolecule in the chain, add a single covalent bond between first atom of incoming biomolecule and the last atom of the last biomolecule
358+
# and not the first biomolecule in the chain, add a single covalent bond between first atom of incoming biomolecule and the last atom of the last biomolecule
353359

354360
if is_chainable_biomolecule and not is_first_mol_in_chain:
355361

362+
_, last_atom_index = prev_src_tgt_atom_indices
363+
first_atom_index, _ = src_tgt_atom_indices
364+
365+
src_atom_offset = offset + first_atom_index
366+
tgt_atom_offset = offset - last_atom_index
356367

357-
atompair_ids[offset, offset - 1] = 1
358-
atompair_ids[offset - 1, offset] = 1
368+
atompair_ids[src_atom_offset, tgt_atom_offset] = 1
369+
atompair_ids[tgt_atom_offset, src_atom_offset] = 1
359370

360-
last_mol = mol
371+
prev_mol = mol
372+
prev_src_tgt_atom_indices = src_tgt_atom_indices
361373

362374
# atom_inputs
363375

@@ -538,6 +550,7 @@ def alphafold3_input_to_molecule_input(
538550

539551
distogram_atom_indices = []
540552
molecule_atom_indices = []
553+
src_tgt_atom_indices = []
541554

542555
for protein in proteins:
543556
mol_peptides, protein_entries = map_int_or_string_indices_to_mol(HUMAN_AMINO_ACIDS, protein, chain = True, return_entries = True)
@@ -546,6 +559,8 @@ def alphafold3_input_to_molecule_input(
546559
distogram_atom_indices.extend([entry['token_center_atom_idx'] for entry in protein_entries])
547560
molecule_atom_indices.extend([entry['distogram_atom_idx'] for entry in protein_entries])
548561

562+
src_tgt_atom_indices.extend([[entry['first_atom_idx'], entry['last_atom_idx']] for entry in protein_entries])
563+
549564
protein_ids = maybe_string_to_int(HUMAN_AMINO_ACIDS, protein) + protein_offset
550565
molecule_ids.append(protein_ids)
551566

@@ -560,6 +575,11 @@ def alphafold3_input_to_molecule_input(
560575
mol_seq, ss_rna_entries = map_int_or_string_indices_to_mol(RNA_NUCLEOTIDES, seq, chain = True, return_entries = True)
561576
mol_ss_rnas.append(mol_seq)
562577

578+
distogram_atom_indices.extend([entry['token_center_atom_idx'] for entry in ss_rna_entries])
579+
molecule_atom_indices.extend([entry['distogram_atom_idx'] for entry in ss_rna_entries])
580+
581+
src_tgt_atom_indices.extend([[entry['first_atom_idx'], entry['last_atom_idx']] for entry in ss_rna_entries])
582+
563583
rna_ids = maybe_string_to_int(RNA_NUCLEOTIDES, seq) + rna_offset
564584
molecule_ids.append(rna_ids)
565585

@@ -569,6 +589,11 @@ def alphafold3_input_to_molecule_input(
569589
mol_seq, ss_dna_entries = map_int_or_string_indices_to_mol(DNA_NUCLEOTIDES, seq, chain = True, return_entries = True)
570590
mol_ss_dnas.append(mol_seq)
571591

592+
distogram_atom_indices.extend([entry['token_center_atom_idx'] for entry in ss_dna_entries])
593+
molecule_atom_indices.extend([entry['distogram_atom_idx'] for entry in ss_dna_entries])
594+
595+
src_tgt_atom_indices.extend([[entry['first_atom_idx'], entry['last_atom_idx']] for entry in ss_dna_entries])
596+
572597
dna_ids = maybe_string_to_int(DNA_NUCLEOTIDES, seq) + dna_offset
573598
molecule_ids.append(dna_ids)
574599

@@ -814,6 +839,9 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
814839
molecule_atom_indices = tensor(molecule_atom_indices)
815840
molecule_atom_indices = pad_to_len(molecule_atom_indices, num_tokens, value = -1)
816841

842+
src_tgt_atom_indices = tensor(src_tgt_atom_indices)
843+
src_tgt_atom_indices = pad_to_len(src_tgt_atom_indices, num_tokens, value = -1, dim = -2)
844+
817845
# todo - handle atom positions for variable lengthed atoms (eventual missing atoms from mmCIF)
818846

819847
atom_pos = i.atom_pos
@@ -830,6 +858,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
830858
additional_molecule_feats = additional_molecule_feats,
831859
additional_token_feats = default(i.additional_token_feats, torch.zeros(num_tokens, 2)),
832860
is_molecule_types = is_molecule_types,
861+
src_tgt_atom_indices = src_tgt_atom_indices,
833862
atom_pos = atom_pos,
834863
templates = i.templates,
835864
msa = i.msa,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.121"
3+
version = "0.1.122"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)