Skip to content

Commit cd97810

Browse files
committed
allow for directed bonds, as defined by rdkit
1 parent c1fa9aa commit cd97810

File tree

3 files changed

+25
-5
lines changed

3 files changed

+25
-5
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ class MoleculeInput:
180180
resolved_labels: Int[' n'] | None = None
181181
add_atom_ids: bool = False
182182
add_atompair_ids: bool = False
183+
directed_bonds: bool = False
183184
extract_atom_feats_fn: Callable[[Atom], Float['m dai']] = default_extract_atom_feats_fn
184185
extract_atompair_feats_fn: Callable[[Mol], Float['m m dapi']] = default_extract_atompair_feats_fn
185186

@@ -247,6 +248,8 @@ def molecule_to_atom_input(
247248

248249
if i.add_atompair_ids:
249250
atom_bond_index = {symbol: (idx + 1) for idx, symbol in enumerate(ATOM_BONDS)}
251+
num_atom_bond_types = len(atom_bond_index)
252+
250253
other_index = len(ATOM_BONDS) + 1
251254

252255
atompair_ids = torch.zeros(total_atoms, total_atoms).long()
@@ -284,7 +287,17 @@ def molecule_to_atom_input(
284287
bond_type = bond.GetBondType()
285288
bond_id = atom_bond_index.get(bond_type, other_index) + 1
286289

287-
updates.extend([bond_id, bond_id])
290+
# default to symmetric bond type (undirected atom bonds)
291+
292+
bond_to = bond_from = bond_id
293+
294+
# if allowing for directed bonds, assume num_atompair_embeds = (2 * num_atom_bond_types) + 1
295+
# offset other edge by num_atom_bond_types
296+
297+
if i.directed_bonds:
298+
bond_from += num_atom_bond_types
299+
300+
updates.extend([bond_to, bond_from])
288301

289302
coordinates = tensor(coordinates).long()
290303
updates = tensor(updates).long()
@@ -386,6 +399,7 @@ class Alphafold3Input:
386399
add_atom_ids: bool = False
387400
add_atompair_ids: bool = False
388401
add_output_atompos_indices: bool = True
402+
directed_bonds: bool = False
389403
extract_atom_feats_fn: Callable[[Atom], Float['m dai']] = default_extract_atom_feats_fn
390404
extract_atompair_feats_fn: Callable[[Mol], Float['m m dapi']] = default_extract_atompair_feats_fn
391405

@@ -833,6 +847,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
833847
atom_parent_ids = atom_parent_ids,
834848
add_atom_ids = i.add_atom_ids,
835849
add_atompair_ids = i.add_atompair_ids,
850+
directed_bonds = i.directed_bonds,
836851
extract_atom_feats_fn = i.extract_atom_feats_fn,
837852
extract_atompair_feats_fn = i.extract_atompair_feats_fn
838853
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.106"
3+
version = "0.1.107"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_input.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
import torch
23

34
from alphafold3_pytorch import (
@@ -23,7 +24,8 @@ def test_tensor_reverse_complement():
2324
rc = reverse_complement_tensor(seq)
2425
assert torch.allclose(reverse_complement_tensor(rc), seq)
2526

26-
def test_alphafold3_input():
27+
@pytest.mark.parametrize('directed_bonds', (False, True))
28+
def test_alphafold3_input(directed_bonds):
2729

2830
alphafold3_input = Alphafold3Input(
2931
proteins = ['MLEICLKLVGCKSKKGLSSSSSCYLEEALQRPVASDF', 'MGKCRGLRTARKLRSHRRDQKWHDKQYKKAHLGTALKANPFGGASHAKGIVLEKVGVEAKQPNSAIRKCVRVQLIKNGKKITAFVPNDGCLNFIEENDEVLVAGFGRKGHAVGDIPGVRFKVVKVANVSLLALYKGKKERPRS'],
@@ -35,18 +37,21 @@ def test_alphafold3_input():
3537
misc_molecule_ids = ['Phospholipid'],
3638
ligands = ['CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=CC(=N4)C5=CN=CC=C5'],
3739
add_atom_ids = True,
38-
add_atompair_ids = True
40+
add_atompair_ids = True,
41+
directed_bonds = directed_bonds
3942
)
4043

4144
batched_atom_input = alphafold3_inputs_to_batched_atom_input(alphafold3_input)
4245

4346
# feed it into alphafold3
4447

48+
num_atom_bond_types = (6 * (2 if directed_bonds else 1))
49+
4550
alphafold3 = Alphafold3(
4651
dim_atom_inputs = 3,
4752
dim_atompair_inputs = 1,
4853
num_atom_embeds = 47,
49-
num_atompair_embeds = 6 + 1,
54+
num_atompair_embeds = num_atom_bond_types + 1, # 0 is for no bond
5055
atoms_per_window = 27,
5156
dim_template_feats = 44,
5257
num_dist_bins = 38,

0 commit comments

Comments
 (0)