Skip to content

Commit 6241e10

Browse files
committed
address #122
1 parent 0205136 commit 6241e10

12 files changed

+61
-68
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ batched_atom_input = alphafold3_inputs_to_batched_atom_input(train_alphafold3_in
174174

175175
alphafold3 = Alphafold3(
176176
dim_atom_inputs = 3,
177-
dim_atompair_inputs = 1,
177+
dim_atompair_inputs = 5,
178178
atoms_per_window = 27,
179179
dim_template_feats = 44,
180180
num_dist_bins = 38,

alphafold3_pytorch/alphafold3.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -230,57 +230,6 @@ def pad_and_window(
230230
t = rearrange(t, 'b (n w) ... -> b n w ...', w = window_size)
231231
return t
232232

233-
# to atompair input functions
234-
235-
@typecheck
236-
def atom_ref_pos_to_atompair_inputs(
237-
atom_ref_pos: Float['... m 3'],
238-
atom_ref_space_uid: Int['... m'],
239-
) -> Float['... m m 5']:
240-
241-
# Algorithm 5 - lines 2-6
242-
# allow for either batched or single
243-
244-
atom_ref_pos, unpack_one = pack_one(atom_ref_pos, '* m c')
245-
atom_ref_space_uid, _ = pack_one(atom_ref_space_uid, '* m')
246-
247-
assert atom_ref_pos.shape[0] == atom_ref_space_uid.shape[0]
248-
249-
# line 2
250-
251-
pairwise_rel_pos = einx.subtract('b i c, b j c -> b i j c', atom_ref_pos, atom_ref_pos)
252-
253-
# line 3
254-
255-
same_ref_space_mask = einx.equal('b i, b j -> b i j', atom_ref_space_uid, atom_ref_space_uid)
256-
257-
# line 5 - pairwise inverse squared distance
258-
259-
atom_inv_square_dist = (1 + pairwise_rel_pos.norm(dim = -1, p = 2) ** 2) ** -1
260-
261-
# concat all into atompair_inputs for projection into atompair_feats within Alphafold3
262-
263-
atompair_inputs, _ = pack((
264-
pairwise_rel_pos,
265-
atom_inv_square_dist,
266-
same_ref_space_mask.float(),
267-
), 'b i j *')
268-
269-
# mask out
270-
271-
atompair_inputs = einx.where(
272-
'b i j, b i j dapi, -> b i j dapi',
273-
same_ref_space_mask, atompair_inputs, 0.
274-
)
275-
276-
# reconstitute optional batch dimension
277-
278-
atompair_inputs = unpack_one(atompair_inputs, '* i j dapi')
279-
280-
# return
281-
282-
return atompair_inputs
283-
284233
# packed atom representation functions
285234

286235
@typecheck

alphafold3_pytorch/inputs.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import Any, Callable, List, Literal, Set, Tuple, Type
1313

1414
import einx
15+
from einops import pack
1516

1617
import numpy as np
1718
from numpy.lib.format import open_memmap
@@ -54,7 +55,6 @@
5455
reverse_complement_tensor,
5556
)
5657

57-
5858
from alphafold3_pytorch.tensor_typing import Bool, Float, Int, typecheck
5959
from alphafold3_pytorch.utils.data_utils import RESIDUE_MOLECULE_TYPE, get_residue_molecule_type
6060
from alphafold3_pytorch.utils.model_utils import exclusive_cumsum
@@ -298,6 +298,51 @@ def __getitem__(self, idx: int) -> AtomInput:
298298
file = self.files[idx]
299299
return file_to_atom_input(file)
300300

301+
# atom reference position to atompair inputs
302+
# will be used in the `default_extract_atompair_feats_fn` below in MoleculeInput
303+
304+
@typecheck
305+
def atom_ref_pos_to_atompair_inputs(
306+
atom_ref_pos: Float['m 3'],
307+
atom_ref_space_uid: Int['m'] | None = None,
308+
) -> Float['m m 5']:
309+
310+
# Algorithm 5 - lines 2-6
311+
312+
# line 2
313+
314+
pairwise_rel_pos = einx.subtract('i c, j c -> i j c', atom_ref_pos, atom_ref_pos)
315+
316+
# line 5 - pairwise inverse squared distance
317+
318+
atom_inv_square_dist = (1 + pairwise_rel_pos.norm(dim = -1, p = 2) ** 2) ** -1
319+
320+
# line 3
321+
322+
if exists(atom_ref_space_uid):
323+
same_ref_space_mask = einx.equal('i, j -> i j', atom_ref_space_uid, atom_ref_space_uid)
324+
else:
325+
same_ref_space_mask = torch.ones_like(atom_inv_square_dist).bool()
326+
327+
# concat all into atompair_inputs for projection into atompair_feats within Alphafold3
328+
329+
atompair_inputs, _ = pack((
330+
pairwise_rel_pos,
331+
atom_inv_square_dist,
332+
same_ref_space_mask.float(),
333+
), 'i j *')
334+
335+
# mask out
336+
337+
atompair_inputs = einx.where(
338+
'i j, i j dapi, -> i j dapi',
339+
same_ref_space_mask, atompair_inputs, 0.
340+
)
341+
342+
# return
343+
344+
return atompair_inputs
345+
301346
# molecule input - accepting list of molecules as rdchem.Mol + the atomic lengths for how to pool into tokens
302347

303348
def default_extract_atom_feats_fn(atom: Atom):
@@ -316,8 +361,7 @@ def default_extract_atompair_feats_fn(mol: Mol):
316361

317362
all_atom_pos_tensor = tensor(all_atom_pos)
318363

319-
dist_matrix = torch.cdist(all_atom_pos_tensor, all_atom_pos_tensor)
320-
return torch.stack((dist_matrix,), dim = -1)
364+
return atom_ref_pos_to_atompair_inputs(all_atom_pos_tensor) # what they did in the paper, but can be overwritten
321365

322366
@typecheck
323367
@dataclass

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.76"
3+
version = "0.2.77"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/configs/trainer_with_atom_dataset_created_from_pdb.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ model:
77
dim_pairwise: 4
88
dim_token: 4
99
dim_atom_inputs: 3
10-
dim_atompair_inputs: 1
10+
dim_atompair_inputs: 5
1111
dim_template_model: 8
1212
atoms_per_window: 27
1313
dim_template_feats: 44

tests/configs/trainer_with_pdb_dataset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ model:
77
dim_pairwise: 4
88
dim_token: 4
99
dim_atom_inputs: 3
10-
dim_atompair_inputs: 1
10+
dim_atompair_inputs: 5
1111
dim_template_model: 8
1212
atoms_per_window: 27
1313
dim_template_feats: 44

tests/configs/trainer_with_pdb_dataset_and_weighted_sampling.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ model:
77
dim_pairwise: 4
88
dim_token: 4
99
dim_atom_inputs: 3
10-
dim_atompair_inputs: 1
10+
dim_atompair_inputs: 5
1111
dim_template_model: 8
1212
atoms_per_window: 27
1313
dim_template_feats: 44

tests/configs/training_with_pdb_dataset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ model:
1313
dim_pairwise: 4
1414
dim_token: 4
1515
dim_atom_inputs: 3
16-
dim_atompair_inputs: 1
16+
dim_atompair_inputs: 5
1717
dim_template_model: 8
1818
atoms_per_window: 27
1919
dim_template_feats: 44

tests/test_af3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@
4040
from alphafold3_pytorch.alphafold3 import (
4141
mean_pool_with_lens,
4242
repeat_consecutive_with_lens,
43-
full_pairwise_repr_to_windowed,
44-
atom_ref_pos_to_atompair_inputs
43+
full_pairwise_repr_to_windowed
4544
)
4645

4746
from alphafold3_pytorch.inputs import (
48-
IS_MOLECULE_TYPES
47+
IS_MOLECULE_TYPES,
48+
atom_ref_pos_to_atompair_inputs
4949
)
5050

5151
def test_atom_ref_pos_to_atompair_inputs():

tests/test_dataloading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_data_input():
4141

4242
alphafold3 = Alphafold3(
4343
dim_atom_inputs=3,
44-
dim_atompair_inputs=1,
44+
dim_atompair_inputs=5,
4545
atoms_per_window=27,
4646
dim_template_feats=44,
4747
num_dist_bins=38,

0 commit comments

Comments
 (0)