Skip to content

Commit 3a246b7

Browse files
committed
complete the logic for designating the peptide and phosphodiesterase bond for chains of polypeptides and nucleic acids
1 parent 9b7a88e commit 3a246b7

File tree

3 files changed

+36
-19
lines changed

3 files changed

+36
-19
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,17 @@ def molecule_to_atom_input(
154154
mol_input: MoleculeInput
155155
) -> AtomInput:
156156

157-
molecules = mol_input.molecules
158-
atom_lens = mol_input.molecule_token_pool_lens
157+
i = mol_input
158+
159+
molecules = i.molecules
160+
atom_lens = i.molecule_token_pool_lens
159161

160162
# get total number of atoms
161163

162164
if not exists(atom_lens):
163165
atom_lens = []
164166

165-
for mol, is_ligand in zip(molecules, mol_input.is_molecule_types[:, -1]):
167+
for mol, is_ligand in zip(molecules, i.is_molecule_types[:, -1]):
166168
num_atoms = mol.GetNumAtoms()
167169

168170
if is_ligand:
@@ -184,7 +186,7 @@ def molecule_to_atom_input(
184186

185187
atom_ids = None
186188

187-
if mol_input.add_atom_ids:
189+
if i.add_atom_ids:
188190
atom_index = {symbol: i for i, symbol in enumerate(ATOMS)}
189191

190192
atom_ids = []
@@ -201,14 +203,22 @@ def molecule_to_atom_input(
201203

202204
atompair_ids = None
203205

204-
if mol_input.add_atompair_ids:
205-
atom_bond_index = {symbol: (i + 1) for i, symbol in enumerate(ATOM_BONDS)}
206+
if i.add_atompair_ids:
207+
atom_bond_index = {symbol: (idx + 1) for idx, symbol in enumerate(ATOM_BONDS)}
206208
other_index = len(ATOM_BONDS) + 1
207209

208210
atompair_ids = torch.zeros(total_atoms, total_atoms).long()
211+
209212
offset = 0
210213

211-
for mol in molecules:
214+
# need the asym_id (each molecule for each chain ascending) as well as `is_protein | is_dna | is_rna` for is_molecule_types (chainable biomolecules)
215+
# will do a single bond from a peptide or nucleotide to the one before, if `asym_id` != 0 (first in the chain)
216+
217+
asym_ids = i.additional_molecule_feats[..., 2]
218+
is_chainable_biomolecules = i.is_molecule_types[..., :3].any(dim = -1)
219+
220+
for idx, (mol, asym_id, is_chainable_biomolecule) in enumerate(zip(molecules, asym_ids, is_chainable_biomolecules)):
221+
212222
coordinates = []
213223
updates = []
214224

@@ -225,7 +235,7 @@ def molecule_to_atom_input(
225235
])
226236

227237
bond_type = bond.GetBondType()
228-
bond_id = atom_bond_index.get(bond_type, other_index)
238+
bond_id = atom_bond_index.get(bond_type, other_index) + 1
229239

230240
updates.extend([bond_id, bond_id])
231241

@@ -237,7 +247,14 @@ def molecule_to_atom_input(
237247
row_col_slice = slice(offset, offset + num_atoms)
238248
atompair_ids[row_col_slice, row_col_slice] = mol_atompair_ids
239249

240-
offset += num_atoms
250+
# if is chainable biomolecule
251+
# and not the first biomolecule in the chain, add a single covalent bond between first atom of incoming biomolecule and the last atom of the last biomolecule
252+
253+
if is_chainable_biomolecule and asym_id != 0:
254+
atompair_ids[offset, offset - 1] = 1
255+
atompair_ids[offset - 1, offset] = 1
256+
257+
offset += num_atoms
241258

242259
# atom_inputs
243260

@@ -266,8 +283,8 @@ def molecule_to_atom_input(
266283

267284
all_atom_pos = []
268285

269-
for i, atom in enumerate(mol.GetAtoms()):
270-
pos = mol.GetConformer().GetAtomPosition(i)
286+
for idx, atom in enumerate(mol.GetAtoms()):
287+
pos = mol.GetConformer().GetAtomPosition(idx)
271288
all_atom_pos.append([pos.x, pos.y, pos.z])
272289

273290
all_atom_pos_tensor = tensor(all_atom_pos)
@@ -285,12 +302,12 @@ def molecule_to_atom_input(
285302
atom_inputs = tensor(atom_inputs, dtype = torch.float),
286303
atompair_inputs = atompair_inputs,
287304
molecule_atom_lens = tensor(atom_lens, dtype = torch.long),
288-
molecule_ids = mol_input.molecule_ids,
289-
additional_token_feats = mol_input.additional_token_feats,
290-
additional_molecule_feats = mol_input.additional_molecule_feats,
291-
is_molecule_types = mol_input.is_molecule_types,
292-
token_bonds = mol_input.token_bonds,
293-
atom_parent_ids = mol_input.atom_parent_ids,
305+
molecule_ids = i.molecule_ids,
306+
additional_token_feats = i.additional_token_feats,
307+
additional_molecule_feats = i.additional_molecule_feats,
308+
is_molecule_types = i.is_molecule_types,
309+
token_bonds = i.token_bonds,
310+
atom_parent_ids = i.atom_parent_ids,
294311
atom_ids = atom_ids,
295312
atompair_ids = atompair_ids
296313
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.86"
3+
version = "0.1.88"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_alphafold3_input():
4949
dim_atom_inputs = 3,
5050
dim_atompair_inputs = 1,
5151
num_atom_embeds = 47,
52-
num_atompair_embeds = 6,
52+
num_atompair_embeds = 6 + 1,
5353
atoms_per_window = 27,
5454
dim_template_feats = 44,
5555
num_dist_bins = 38,

0 commit comments

Comments
 (0)