Skip to content

Commit 5b43671

Browse files
committed
Handle the "restypes" explicitly as "molecule_ids" and remove them from the "additional molecule feats"
1 parent d25fbca commit 5b43671

File tree

6 files changed

+59
-18
lines changed

6 files changed

+59
-18
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ atom_seq_len = molecule_atom_lens.sum(dim = -1).amax()
5050
atom_inputs = torch.randn(2, atom_seq_len, 77)
5151
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
5252

53-
additional_molecule_feats = torch.randn(2, seq_len, 10)
53+
additional_molecule_feats = torch.randn(2, seq_len, 9)
54+
molecule_ids = torch.randint(0, 32, (2, seq_len))
5455

5556
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
5657
template_mask = torch.ones((2, 2)).bool()
@@ -75,6 +76,7 @@ loss = alphafold3(
7576
num_recycling_steps = 2,
7677
atom_inputs = atom_inputs,
7778
atompair_inputs = atompair_inputs,
79+
molecule_ids = molecule_ids,
7880
molecule_atom_lens = molecule_atom_lens,
7981
additional_molecule_feats = additional_molecule_feats,
8082
msa = msa,
@@ -99,6 +101,7 @@ sampled_atom_pos = alphafold3(
99101
num_sample_steps = 16,
100102
atom_inputs = atom_inputs,
101103
atompair_inputs = atompair_inputs,
104+
molecule_ids = molecule_ids,
102105
molecule_atom_lens = molecule_atom_lens,
103106
additional_molecule_feats = additional_molecule_feats,
104107
msa = msa,

alphafold3_pytorch/alphafold3.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,23 +77,22 @@
7777
"""
7878

7979
"""
80-
additional_molecule_feats: [*, 10]:
80+
additional_molecule_feats: [*, 9]:
8181
8282
0: molecule_index
8383
1: token_index
8484
2: asym_id
8585
3: entity_id
8686
4: sym_id
87-
5: restype (must be one hot encoded to 32)
88-
6: is_protein
89-
7: is_rna
90-
8: is_dna
91-
9: is_ligand
87+
5: is_protein
88+
6: is_rna
89+
7: is_dna
90+
8: is_ligand
9291
"""
9392

9493
# constants
9594

96-
ADDITIONAL_MOLECULE_FEATS = 10
95+
ADDITIONAL_MOLECULE_FEATS = 9
9796

9897
LinearNoBias = partial(Linear, bias = False)
9998

@@ -2196,7 +2195,7 @@ def forward(
21962195
align_weights = atom_pos_ground_truth.new_ones(atom_pos_ground_truth.shape[:2])
21972196

21982197
if exists(additional_molecule_feats):
2199-
is_nucleotide_or_ligand_fields = (additional_molecule_feats[..., 7:] != 0.).unbind(dim = -1)
2198+
is_nucleotide_or_ligand_fields = (additional_molecule_feats[..., -3:] != 0.).unbind(dim = -1)
22002199

22012200
is_nucleotide_or_ligand_fields = tuple(repeat_consecutive_with_lens(t, molecule_atom_lens) for t in is_nucleotide_or_ligand_fields)
22022201
is_nucleotide_or_ligand_fields = tuple(pad_or_slice_to(t, length = align_weights.shape[-1], dim = -1) for t in is_nucleotide_or_ligand_fields)
@@ -2587,6 +2586,7 @@ def __init__(
25872586
dim_token = 384,
25882587
dim_single = 384,
25892588
dim_pairwise = 128,
2589+
num_molecule_types = 32,
25902590
atom_transformer_blocks = 3,
25912591
atom_transformer_heads = 4,
25922592
atom_transformer_kwargs: dict = dict(),
@@ -2632,6 +2632,11 @@ def __init__(
26322632
self.single_input_to_single_init = LinearNoBias(dim_single_input, dim_single)
26332633
self.single_input_to_pairwise_init = LinearNoBiasThenOuterSum(dim_single_input, dim_pairwise)
26342634

2635+
# embeddings for the molecule ids (the `restypes`), which were previously a field of the additional molecule features
2636+
2637+
self.single_molecule_embed = nn.Embedding(num_molecule_types, dim_single)
2638+
self.pairwise_molecule_embed = nn.Embedding(num_molecule_types, dim_pairwise)
2639+
26352640
@typecheck
26362641
def forward(
26372642
self,
@@ -2641,6 +2646,7 @@ def forward(
26412646
atom_mask: Bool['b m'],
26422647
additional_molecule_feats: Float[f'b n {ADDITIONAL_MOLECULE_FEATS}'],
26432648
molecule_atom_lens: Int['b n'],
2649+
molecule_ids: Int['b n']
26442650

26452651
) -> EmbeddedInputs:
26462652

@@ -2691,6 +2697,20 @@ def forward(
26912697
single_init = self.single_input_to_single_init(single_inputs)
26922698
pairwise_init = self.single_input_to_pairwise_init(single_inputs)
26932699

2700+
# account for molecule id (restypes)
2701+
2702+
molecule_ids = torch.where(molecule_ids >= 0, molecule_ids, 0) # account for padding
2703+
2704+
single_molecule_embed = self.single_molecule_embed(molecule_ids)
2705+
2706+
pairwise_molecule_embed = self.pairwise_molecule_embed(molecule_ids)
2707+
pairwise_molecule_embed = einx.add('b i dp, b j dp -> b i j dp', pairwise_molecule_embed, pairwise_molecule_embed)
2708+
2709+
# add the molecule embeddings to the single and pairwise init - analogous to the one-hot restype formerly carried in the additional molecule features
2710+
2711+
single_init = single_init + single_molecule_embed
2712+
pairwise_init = pairwise_init + pairwise_molecule_embed
2713+
26942714
return EmbeddedInputs(single_inputs, single_init, pairwise_init, atom_feats, atompair_feats)
26952715

26962716
# distogram head
@@ -2872,6 +2892,7 @@ def __init__(
28722892
dim_single = 384,
28732893
dim_pairwise = 128,
28742894
dim_token = 768,
2895+
num_molecule_types: int = 32, # restype in additional residue information, apparently 32 (must be human amino acids + nucleotides + something else)
28752896
num_atom_embeds: int | None = None,
28762897
num_atompair_embeds: int | None = None,
28772898
distance_bins: List[float] = torch.linspace(3, 20, 38).float().tolist(),
@@ -3192,6 +3213,7 @@ def forward(
31923213
atompair_inputs: Float['b m m dapi'] | Float['b nw w1 w2 dapi'],
31933214
additional_molecule_feats: Float[f'b n {ADDITIONAL_MOLECULE_FEATS}'],
31943215
molecule_atom_lens: Int['b n'],
3216+
molecule_ids: Int['b n'],
31953217
atom_ids: Int['b m'] | None = None,
31963218
atompair_ids: Int['b m m'] | Int['b nw w1 w2'] | None = None,
31973219
atom_mask: Bool['b m'] | None = None,
@@ -3265,7 +3287,8 @@ def forward(
32653287
atompair_inputs = atompair_inputs,
32663288
atom_mask = atom_mask,
32673289
additional_molecule_feats = additional_molecule_feats,
3268-
molecule_atom_lens = molecule_atom_lens
3290+
molecule_atom_lens = molecule_atom_lens,
3291+
molecule_ids = molecule_ids
32693292
)
32703293

32713294
# handle maybe atom and atompair embeddings

alphafold3_pytorch/inputs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
@typecheck
1111
class AtomInput(TypedDict):
1212
atom_inputs: Float['*b m dai']
13+
molecule_ids: Int['*b n']
1314
molecule_atom_lens: Int['*b n']
1415
atompair_inputs: Float['*b m m dapi'] | Float['*b nw w (w*2) dapi']
1516
additional_molecule_feats: Float['*b n 10']

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.52"
3+
version = "0.1.53"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def test_diffusion_module():
330330
assert sampled_atom_pos.shape == noised_atom_pos.shape
331331

332332
def test_relative_position_encoding():
333-
additional_molecule_feats = torch.randn(8, 100, 10)
333+
additional_molecule_feats = torch.randn(8, 100, 9)
334334

335335
embedder = RelativePositionEncoding()
336336

@@ -387,7 +387,8 @@ def test_input_embedder():
387387
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
388388

389389
atom_mask = torch.ones((2, atom_seq_len)).bool()
390-
additional_molecule_feats = torch.randn(2, 16, 10)
390+
additional_molecule_feats = torch.randn(2, 16, 9)
391+
molecule_ids = torch.randint(0, 32, (2, 16))
391392

392393
embedder = InputFeatureEmbedder(
393394
dim_atom_inputs = 77,
@@ -398,6 +399,7 @@ def test_input_embedder():
398399
atom_mask = atom_mask,
399400
atompair_inputs = atompair_inputs,
400401
molecule_atom_lens = molecule_atom_lens,
402+
molecule_ids = molecule_ids,
401403
additional_molecule_feats = additional_molecule_feats
402404
)
403405

@@ -429,7 +431,8 @@ def test_alphafold3(
429431
if window_atompair_inputs:
430432
atompair_inputs = full_pairwise_repr_to_windowed(atompair_inputs, window_size = atoms_per_window)
431433

432-
additional_molecule_feats = torch.randn(2, seq_len, 10)
434+
additional_molecule_feats = torch.randn(2, seq_len, 9)
435+
molecule_ids = torch.randint(0, 32, (2, seq_len))
433436

434437
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
435438
template_mask = torch.ones((2, 2)).bool()
@@ -473,6 +476,7 @@ def test_alphafold3(
473476
loss, breakdown = alphafold3(
474477
num_recycling_steps = 2,
475478
atom_inputs = atom_inputs,
479+
molecule_ids = molecule_ids,
476480
molecule_atom_lens = molecule_atom_lens,
477481
atompair_inputs = atompair_inputs,
478482
additional_molecule_feats = additional_molecule_feats,
@@ -496,6 +500,7 @@ def test_alphafold3(
496500
sampled_atom_pos = alphafold3(
497501
num_sample_steps = 16,
498502
atom_inputs = atom_inputs,
503+
molecule_ids = molecule_ids,
499504
molecule_atom_lens = molecule_atom_lens,
500505
atompair_inputs = atompair_inputs,
501506
additional_molecule_feats = additional_molecule_feats,
@@ -513,7 +518,8 @@ def test_alphafold3_without_msa_and_templates():
513518

514519
atom_inputs = torch.randn(2, atom_seq_len, 77)
515520
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
516-
additional_molecule_feats = torch.randn(2, seq_len, 10)
521+
additional_molecule_feats = torch.randn(2, seq_len, 9)
522+
molecule_ids = torch.randint(0, 32, (2, seq_len))
517523

518524
atom_pos = torch.randn(2, atom_seq_len, 3)
519525
molecule_atom_indices = molecule_atom_lens - 1
@@ -550,6 +556,7 @@ def test_alphafold3_without_msa_and_templates():
550556
loss, breakdown = alphafold3(
551557
num_recycling_steps = 2,
552558
atom_inputs = atom_inputs,
559+
molecule_ids = molecule_ids,
553560
molecule_atom_lens = molecule_atom_lens,
554561
atompair_inputs = atompair_inputs,
555562
additional_molecule_feats = additional_molecule_feats,
@@ -572,7 +579,8 @@ def test_alphafold3_force_return_loss():
572579

573580
atom_inputs = torch.randn(2, atom_seq_len, 77)
574581
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
575-
additional_molecule_feats = torch.randn(2, seq_len, 10)
582+
additional_molecule_feats = torch.randn(2, seq_len, 9)
583+
molecule_ids = torch.randint(0, 32, (2, seq_len))
576584

577585
atom_pos = torch.randn(2, atom_seq_len, 3)
578586
molecule_atom_indices = molecule_atom_lens - 1
@@ -609,6 +617,7 @@ def test_alphafold3_force_return_loss():
609617
sampled_atom_pos = alphafold3(
610618
num_recycling_steps = 2,
611619
atom_inputs = atom_inputs,
620+
molecule_ids = molecule_ids,
612621
molecule_atom_lens = molecule_atom_lens,
613622
atompair_inputs = atompair_inputs,
614623
additional_molecule_feats = additional_molecule_feats,
@@ -628,6 +637,7 @@ def test_alphafold3_force_return_loss():
628637
loss, _ = alphafold3(
629638
num_recycling_steps = 2,
630639
atom_inputs = atom_inputs,
640+
molecule_ids = molecule_ids,
631641
molecule_atom_lens = molecule_atom_lens,
632642
atompair_inputs = atompair_inputs,
633643
additional_molecule_feats = additional_molecule_feats,
@@ -658,7 +668,8 @@ def test_alphafold3_with_atom_and_bond_embeddings():
658668
atom_inputs = torch.randn(2, atom_seq_len, 77)
659669
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
660670

661-
additional_molecule_feats = torch.randn(2, seq_len, 10)
671+
additional_molecule_feats = torch.randn(2, seq_len, 9)
672+
molecule_ids = torch.randint(0, 32, (2, seq_len))
662673

663674
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
664675
template_mask = torch.ones((2, 2)).bool()
@@ -685,6 +696,7 @@ def test_alphafold3_with_atom_and_bond_embeddings():
685696
atompair_ids = atompair_ids,
686697
atom_inputs = atom_inputs,
687698
atompair_inputs = atompair_inputs,
699+
molecule_ids = molecule_ids,
688700
molecule_atom_lens = molecule_atom_lens,
689701
additional_molecule_feats = additional_molecule_feats,
690702
msa = msa,

tests/test_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def __getitem__(self, idx):
4646
atompair_inputs = torch.randn(atom_seq_len, atom_seq_len, 5)
4747

4848
molecule_atom_lens = torch.randint(1, self.atoms_per_window, (seq_len,))
49-
additional_molecule_feats = torch.randn(seq_len, 10)
49+
additional_molecule_feats = torch.randn(seq_len, 9)
50+
molecule_ids = torch.randint(0, 32, (seq_len,))
5051

5152
templates = torch.randn(2, seq_len, seq_len, 44)
5253
template_mask = torch.ones((2,)).bool()
@@ -71,6 +72,7 @@ def __getitem__(self, idx):
7172
return AtomInput(
7273
atom_inputs = atom_inputs,
7374
atompair_inputs = atompair_inputs,
75+
molecule_ids = molecule_ids,
7476
molecule_atom_lens = molecule_atom_lens,
7577
additional_molecule_feats = additional_molecule_feats,
7678
templates = templates,

0 commit comments

Comments
 (0)