Skip to content

Commit 244f097

Browse files
committed
researcher pass in constructed atompair inputs, which contains offset relative positions + inverse squared distance + mask itself, which is then all projected by one LinearNoBias within InputFeatureEmbedder
1 parent c02059d commit 244f097

File tree

6 files changed

+32
-20
lines changed

6 files changed

+32
-20
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ seq_len = 16
3737
atom_seq_len = seq_len * 27
3838

3939
atom_inputs = torch.randn(2, atom_seq_len, 77)
40+
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
41+
4042
atom_lens = torch.randint(0, 27, (2, seq_len))
41-
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
4243
additional_residue_feats = torch.randn(2, seq_len, 10)
4344

4445
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
@@ -63,8 +64,8 @@ resolved_labels = torch.randint(0, 2, (2, seq_len))
6364
loss = alphafold3(
6465
num_recycling_steps = 2,
6566
atom_inputs = atom_inputs,
67+
atompair_inputs = atompair_inputs,
6668
residue_atom_lens = atom_lens,
67-
atompair_feats = atompair_feats,
6869
additional_residue_feats = additional_residue_feats,
6970
msa = msa,
7071
msa_mask = msa_mask,
@@ -87,8 +88,8 @@ sampled_atom_pos = alphafold3(
8788
num_recycling_steps = 4,
8889
num_sample_steps = 16,
8990
atom_inputs = atom_inputs,
91+
atompair_inputs = atompair_inputs,
9092
residue_atom_lens = atom_lens,
91-
atompair_feats = atompair_feats,
9293
additional_residue_feats = additional_residue_feats,
9394
msa = msa,
9495
msa_mask = msa_mask,

alphafold3_pytorch/alphafold3.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
ds - feature dimension (single)
1515
dp - feature dimension (pairwise)
1616
dap - feature dimension (atompair)
17+
dapi - feature dimension (atompair input)
1718
da - feature dimension (atom)
19+
dai - feature dimension (atom input)
1820
t - templates
1921
s - msa
2022
r - registers
@@ -2446,6 +2448,7 @@ def __init__(
24462448
self,
24472449
*,
24482450
dim_atom_inputs,
2451+
dim_atompair_inputs = 5,
24492452
atoms_per_window = 27,
24502453
dim_atom = 128,
24512454
dim_atompair = 16,
@@ -2461,6 +2464,8 @@ def __init__(
24612464

24622465
self.to_atom_feats = LinearNoBias(dim_atom_inputs, dim_atom)
24632466

2467+
self.to_atompair_feats = LinearNoBias(dim_atompair_inputs, dim_atompair)
2468+
24642469
self.atom_repr_to_atompair_feat_cond = nn.Sequential(
24652470
nn.LayerNorm(dim_atom),
24662471
LinearNoBiasThenOuterSum(dim_atom, dim_atompair),
@@ -2501,8 +2506,8 @@ def forward(
25012506
self,
25022507
*,
25032508
atom_inputs: Float['b m dai'],
2509+
atompair_inputs: Float['b m m dapi'],
25042510
atom_mask: Bool['b m'],
2505-
atompair_feats: Float['b m m dap'],
25062511
additional_residue_feats: Float[f'b n {ADDITIONAL_RESIDUE_FEATS}'],
25072512
residue_atom_lens: Int['b n'] | None = None,
25082513

@@ -2513,6 +2518,7 @@ def forward(
25132518
w = self.atoms_per_window
25142519

25152520
atom_feats = self.to_atom_feats(atom_inputs)
2521+
atompair_feats = self.to_atompair_feats(atompair_inputs)
25162522

25172523
atom_feats_cond = self.atom_repr_to_atompair_feat_cond(atom_feats)
25182524
atompair_feats = atom_feats_cond + atompair_feats
@@ -2709,6 +2715,7 @@ def __init__(
27092715
dim_template_model = 64,
27102716
atoms_per_window = 27,
27112717
dim_atom = 128,
2718+
dim_atompair_inputs = 5,
27122719
dim_atompair = 16,
27132720
dim_input_embedder_token = 384,
27142721
dim_single = 384,
@@ -2810,6 +2817,7 @@ def __init__(
28102817

28112818
self.input_embedder = InputFeatureEmbedder(
28122819
dim_atom_inputs = dim_atom_inputs,
2820+
dim_atompair_inputs = dim_atompair_inputs,
28132821
atoms_per_window = atoms_per_window,
28142822
dim_atom = dim_atom,
28152823
dim_atompair = dim_atompair,
@@ -2995,7 +3003,7 @@ def forward(
29953003
self,
29963004
*,
29973005
atom_inputs: Float['b m dai'],
2998-
atompair_feats: Float['b m m dap'],
3006+
atompair_inputs: Float['b m m dapi'],
29993007
additional_residue_feats: Float[f'b n {ADDITIONAL_RESIDUE_FEATS}'],
30003008
residue_atom_lens: Int['b n'] | None = None,
30013009
atom_mask: Bool['b m'] | None = None,
@@ -3063,8 +3071,8 @@ def forward(
30633071
atompair_feats
30643072
) = self.input_embedder(
30653073
atom_inputs = atom_inputs,
3074+
atompair_inputs = atompair_inputs,
30663075
atom_mask = atom_mask,
3067-
atompair_feats = atompair_feats,
30683076
additional_residue_feats = additional_residue_feats,
30693077
residue_atom_lens = residue_atom_lens
30703078
)

alphafold3_pytorch/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
class Alphafold3Input(TypedDict):
2626
atom_inputs: Float['m dai']
2727
residue_atom_lens: Int['n 2']
28-
atompair_feats: Float['m m dap']
28+
atompair_inputs: Float['m m dap']
2929
additional_residue_feats: Float['n 10']
3030
templates: Float['t n n dt']
3131
template_mask: Bool['t'] | None

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.62"
3+
version = "0.0.63"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,9 @@ def test_input_embedder():
353353

354354
atom_seq_len = 16 * 27
355355
atom_inputs = torch.randn(2, atom_seq_len, 77)
356+
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
357+
356358
atom_mask = torch.ones((2, atom_seq_len)).bool()
357-
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
358359
additional_residue_feats = torch.randn(2, 16, 10)
359360

360361
embedder = InputFeatureEmbedder(
@@ -364,7 +365,7 @@ def test_input_embedder():
364365
embedder(
365366
atom_inputs = atom_inputs,
366367
atom_mask = atom_mask,
367-
atompair_feats = atompair_feats,
368+
atompair_inputs = atompair_inputs,
368369
additional_residue_feats = additional_residue_feats
369370
)
370371

@@ -383,8 +384,9 @@ def test_alphafold3():
383384
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
384385

385386
atom_inputs = torch.randn(2, atom_seq_len, 77)
387+
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
388+
386389
atom_lens = torch.randint(0, 27, (2, seq_len))
387-
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
388390
additional_residue_feats = torch.randn(2, seq_len, 10)
389391

390392
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
@@ -428,7 +430,7 @@ def test_alphafold3():
428430
num_recycling_steps = 2,
429431
atom_inputs = atom_inputs,
430432
residue_atom_lens = atom_lens,
431-
atompair_feats = atompair_feats,
433+
atompair_inputs = atompair_inputs,
432434
additional_residue_feats = additional_residue_feats,
433435
token_bond = token_bond,
434436
msa = msa,
@@ -451,7 +453,7 @@ def test_alphafold3():
451453
num_sample_steps = 16,
452454
atom_inputs = atom_inputs,
453455
residue_atom_lens = atom_lens,
454-
atompair_feats = atompair_feats,
456+
atompair_inputs = atompair_inputs,
455457
additional_residue_feats = additional_residue_feats,
456458
msa = msa,
457459
templates = template_feats,
@@ -465,8 +467,8 @@ def test_alphafold3_without_msa_and_templates():
465467
atom_seq_len = seq_len * 27
466468

467469
atom_inputs = torch.randn(2, atom_seq_len, 77)
470+
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
468471
atom_lens = torch.randint(0, 27, (2, seq_len))
469-
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
470472
additional_residue_feats = torch.randn(2, seq_len, 10)
471473

472474
atom_pos = torch.randn(2, atom_seq_len, 3)
@@ -505,7 +507,7 @@ def test_alphafold3_without_msa_and_templates():
505507
num_recycling_steps = 2,
506508
atom_inputs = atom_inputs,
507509
residue_atom_lens = atom_lens,
508-
atompair_feats = atompair_feats,
510+
atompair_inputs = atompair_inputs,
509511
additional_residue_feats = additional_residue_feats,
510512
atom_pos = atom_pos,
511513
residue_atom_indices = residue_atom_indices,
@@ -528,8 +530,8 @@ def test_alphafold3_with_packed_atom_repr():
528530
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
529531

530532
atom_inputs = torch.randn(2, atom_seq_len, 77)
533+
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
531534

532-
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
533535
additional_residue_feats = torch.randn(2, seq_len, 10)
534536

535537
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
@@ -574,7 +576,7 @@ def test_alphafold3_with_packed_atom_repr():
574576
num_recycling_steps = 2,
575577
atom_inputs = atom_inputs,
576578
residue_atom_lens = residue_atom_lens,
577-
atompair_feats = atompair_feats,
579+
atompair_inputs = atompair_inputs,
578580
additional_residue_feats = additional_residue_feats,
579581
token_bond = token_bond,
580582
msa = msa,
@@ -596,7 +598,7 @@ def test_alphafold3_with_packed_atom_repr():
596598
num_sample_steps = 16,
597599
atom_inputs = atom_inputs,
598600
residue_atom_lens = residue_atom_lens,
599-
atompair_feats = atompair_feats,
601+
atompair_inputs = atompair_inputs,
600602
additional_residue_feats = additional_residue_feats,
601603
msa = msa,
602604
templates = template_feats,

tests/test_trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ def __getitem__(self, idx):
3434
atom_seq_len = self.atom_seq_len
3535

3636
atom_inputs = torch.randn(atom_seq_len, 77)
37+
atompair_inputs = torch.randn(atom_seq_len, atom_seq_len, 5)
38+
3739
residue_atom_lens = torch.randint(0, 27, (seq_len,))
38-
atompair_feats = torch.randn(atom_seq_len, atom_seq_len, 16)
3940
additional_residue_feats = torch.randn(seq_len, 10)
4041

4142
templates = torch.randn(2, seq_len, seq_len, 44)
@@ -57,8 +58,8 @@ def __getitem__(self, idx):
5758

5859
return Alphafold3Input(
5960
atom_inputs = atom_inputs,
61+
atompair_inputs = atompair_inputs,
6062
residue_atom_lens = residue_atom_lens,
61-
atompair_feats = atompair_feats,
6263
additional_residue_feats = additional_residue_feats,
6364
templates = templates,
6465
template_mask = template_mask,

0 commit comments

Comments
 (0)