Skip to content

Commit cb4acf2

Browse files
committed
break out a boolean tensor from additional_molecule_feats
1 parent 4c8a341 commit cb4acf2

File tree

6 files changed

+67
-26
lines changed

6 files changed

+67
-26
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ atom_seq_len = molecule_atom_lens.sum(dim = -1).amax()
5454
atom_inputs = torch.randn(2, atom_seq_len, 77)
5555
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
5656

57-
additional_molecule_feats = torch.randn(2, seq_len, 9)
57+
additional_molecule_feats = torch.randn(2, seq_len, 5)
58+
is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool()
5859
molecule_ids = torch.randint(0, 32, (2, seq_len))
5960

6061
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
@@ -83,6 +84,7 @@ loss = alphafold3(
8384
molecule_ids = molecule_ids,
8485
molecule_atom_lens = molecule_atom_lens,
8586
additional_molecule_feats = additional_molecule_feats,
87+
is_molecule_types = is_molecule_types,
8688
msa = msa,
8789
msa_mask = msa_mask,
8890
templates = template_feats,

alphafold3_pytorch/alphafold3.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,30 @@
7878
"""
7979

8080
"""
81-
additional_molecule_feats: [*, 9]:
81+
additional_molecule_feats: [*, 5]:
8282
8383
0: molecule_index
8484
1: token_index
8585
2: asym_id
8686
3: entity_id
8787
4: sym_id
88-
5: is_protein
89-
6: is_rna
90-
7: is_dna
91-
8: is_ligand
88+
"""
89+
90+
"""
91+
is_molecule_types: [*, 4]
92+
93+
0: is_protein
94+
1: is_rna
95+
2: is_dna
96+
3: is_ligand
9297
"""
9398

9499
# constants
95100

96-
ADDITIONAL_MOLECULE_FEATS = 9
101+
from alphafold3_pytorch.inputs import (
102+
IS_MOLECULE_TYPES,
103+
ADDITIONAL_MOLECULE_FEATS
104+
)
97105

98106
LinearNoBias = partial(Linear, bias = False)
99107

@@ -1169,9 +1177,8 @@ def forward(
11691177
) -> Float['b n n dp']:
11701178

11711179
device = additional_molecule_feats.device
1172-
assert additional_molecule_feats.shape[-1] >= 5
11731180

1174-
res_idx, token_idx, asym_id, entity_id, sym_id = additional_molecule_feats[..., :5].unbind(dim = -1)
1181+
res_idx, token_idx, asym_id, entity_id, sym_id = additional_molecule_feats.unbind(dim = -1)
11751182

11761183
diff_res_idx = einx.subtract('b i, b j -> b i j', res_idx, res_idx)
11771184
diff_token_idx = einx.subtract('b i, b j -> b i j', token_idx, token_idx)
@@ -2173,6 +2180,7 @@ def forward(
21732180
molecule_atom_lens: Int['b n'],
21742181
atom_parent_ids: Int['b m'] | None = None,
21752182
return_denoised_pos = False,
2183+
is_molecule_types: Bool[f'b n {IS_MOLECULE_TYPES}'] | None = None,
21762184
additional_molecule_feats: Float[f'b n {ADDITIONAL_MOLECULE_FEATS}'] | None = None,
21772185
add_smooth_lddt_loss = False,
21782186
add_bond_loss = False,
@@ -2218,13 +2226,13 @@ def forward(
22182226

22192227
align_weights = atom_pos_ground_truth.new_ones(atom_pos_ground_truth.shape[:2])
22202228

2221-
if exists(additional_molecule_feats):
2222-
is_nucleotide_or_ligand_fields = (additional_molecule_feats[..., -3:] != 0.).unbind(dim = -1)
2229+
if exists(is_molecule_types):
2230+
is_nucleotide_or_ligand_fields = is_molecule_types.unbind(dim = -1)
22232231

22242232
is_nucleotide_or_ligand_fields = tuple(repeat_consecutive_with_lens(t, molecule_atom_lens) for t in is_nucleotide_or_ligand_fields)
22252233
is_nucleotide_or_ligand_fields = tuple(pad_or_slice_to(t, length = align_weights.shape[-1], dim = -1) for t in is_nucleotide_or_ligand_fields)
22262234

2227-
atom_is_dna, atom_is_rna, atom_is_ligand = is_nucleotide_or_ligand_fields
2235+
_, atom_is_dna, atom_is_rna, atom_is_ligand = is_nucleotide_or_ligand_fields
22282236

22292237
# section 3.7.1 equation 4
22302238

@@ -2281,7 +2289,7 @@ def forward(
22812289
smooth_lddt_loss = self.zero
22822290

22832291
if add_smooth_lddt_loss:
2284-
assert exists(additional_molecule_feats)
2292+
assert exists(is_molecule_types)
22852293

22862294
smooth_lddt_loss = self.smooth_lddt_loss(
22872295
denoised_atom_pos,
@@ -2651,7 +2659,7 @@ def __init__(
26512659
dim_out = dim_token
26522660
)
26532661

2654-
dim_single_input = dim_token + ADDITIONAL_MOLECULE_FEATS
2662+
dim_single_input = dim_token + ADDITIONAL_MOLECULE_FEATS + IS_MOLECULE_TYPES
26552663

26562664
self.single_input_to_single_init = LinearNoBias(dim_single_input, dim_single)
26572665
self.single_input_to_pairwise_init = LinearNoBiasThenOuterSum(dim_single_input, dim_pairwise)
@@ -2668,6 +2676,7 @@ def forward(
26682676
atom_inputs: Float['b m dai'],
26692677
atompair_inputs: Float['b m m dapi'] | Float['b nw w1 w2 dapi'],
26702678
atom_mask: Bool['b m'],
2679+
is_molecule_types: Bool[f'b n {IS_MOLECULE_TYPES}'],
26712680
additional_molecule_feats: Float[f'b n {ADDITIONAL_MOLECULE_FEATS}'],
26722681
molecule_atom_lens: Int['b n'],
26732682
molecule_ids: Int['b n']
@@ -2716,7 +2725,11 @@ def forward(
27162725
molecule_atom_lens = molecule_atom_lens
27172726
)
27182727

2719-
single_inputs = torch.cat((single_inputs, additional_molecule_feats), dim = -1)
2728+
single_inputs = torch.cat((
2729+
single_inputs,
2730+
additional_molecule_feats,
2731+
is_molecule_types.float()
2732+
), dim = -1)
27202733

27212734
single_init = self.single_input_to_single_init(single_inputs)
27222735
pairwise_init = self.single_input_to_pairwise_init(single_inputs)
@@ -3046,7 +3059,7 @@ def __init__(
30463059
**input_embedder_kwargs
30473060
)
30483061

3049-
dim_single_inputs = dim_input_embedder_token + ADDITIONAL_MOLECULE_FEATS
3062+
dim_single_inputs = dim_input_embedder_token + ADDITIONAL_MOLECULE_FEATS + IS_MOLECULE_TYPES
30503063

30513064
# relative positional encoding
30523065
# used by pairwise in main alphafold2 trunk
@@ -3236,6 +3249,7 @@ def forward(
32363249
atom_inputs: Float['b m dai'],
32373250
atompair_inputs: Float['b m m dapi'] | Float['b nw w1 w2 dapi'],
32383251
additional_molecule_feats: Float[f'b n {ADDITIONAL_MOLECULE_FEATS}'],
3252+
is_molecule_types: Bool[f'b n {IS_MOLECULE_TYPES}'],
32393253
molecule_atom_lens: Int['b n'],
32403254
molecule_ids: Int['b n'],
32413255
atom_ids: Int['b m'] | None = None,
@@ -3311,6 +3325,7 @@ def forward(
33113325
atom_inputs = atom_inputs,
33123326
atompair_inputs = atompair_inputs,
33133327
atom_mask = atom_mask,
3328+
is_molecule_types = is_molecule_types,
33143329
additional_molecule_feats = additional_molecule_feats,
33153330
molecule_atom_lens = molecule_atom_lens,
33163331
molecule_ids = molecule_ids
@@ -3513,6 +3528,7 @@ def forward(
35133528
pairwise,
35143529
relative_position_encoding,
35153530
additional_molecule_feats,
3531+
is_molecule_types,
35163532
molecule_atom_indices,
35173533
molecule_atom_lens,
35183534
pae_labels,
@@ -3535,6 +3551,7 @@ def forward(
35353551
pairwise,
35363552
relative_position_encoding,
35373553
additional_molecule_feats,
3554+
is_molecule_types,
35383555
molecule_atom_indices,
35393556
molecule_atom_lens,
35403557
pae_labels,
@@ -3566,6 +3583,7 @@ def forward(
35663583
diffusion_loss, denoised_atom_pos, diffusion_loss_breakdown, _ = self.edm(
35673584
atom_pos,
35683585
additional_molecule_feats = additional_molecule_feats,
3586+
is_molecule_types = is_molecule_types,
35693587
add_smooth_lddt_loss = diffusion_add_smooth_lddt_loss,
35703588
add_bond_loss = diffusion_add_bond_loss,
35713589
atom_feats = atom_feats,

alphafold3_pytorch/inputs.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
Int, Bool, Float
99
)
1010

11+
# constants
12+
13+
IS_MOLECULE_TYPES = 4
14+
ADDITIONAL_MOLECULE_FEATS = 5
15+
1116
# atom level, what Alphafold3 accepts
1217

1318
@typecheck
@@ -16,7 +21,8 @@ class AtomInput(TypedDict):
1621
molecule_ids: Int['n']
1722
molecule_atom_lens: Int['n']
1823
atompair_inputs: Float['m m dapi'] | Float['nw w (w*2) dapi']
19-
additional_molecule_feats: Float['n 9']
24+
additional_molecule_feats: Float[f'n {ADDITIONAL_MOLECULE_FEATS}']
25+
is_molecule_types: Bool[f'n {IS_MOLECULE_TYPES}']
2026
templates: Float['t n n dt']
2127
msa: Float['s n dm']
2228
token_bonds: Bool['n n'] | None
@@ -38,7 +44,8 @@ class BatchedAtomInput(TypedDict):
3844
molecule_ids: Int['b n']
3945
molecule_atom_lens: Int['b n']
4046
atompair_inputs: Float['b m m dapi'] | Float['b nw w (w*2) dapi']
41-
additional_molecule_feats: Float['b n 9']
47+
additional_molecule_feats: Float[f'b n {ADDITIONAL_MOLECULE_FEATS}']
48+
is_molecule_types: Bool[f'b n {IS_MOLECULE_TYPES}']
4249
templates: Float['b t n n dt']
4350
msa: Float['b s n dm']
4451
token_bonds: Bool['b n n'] | None

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.67"
3+
version = "0.1.68"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def test_diffusion_module():
330330
assert sampled_atom_pos.shape == noised_atom_pos.shape
331331

332332
def test_relative_position_encoding():
333-
additional_molecule_feats = torch.randn(8, 100, 9)
333+
additional_molecule_feats = torch.randn(8, 100, 5)
334334

335335
embedder = RelativePositionEncoding()
336336

@@ -387,7 +387,8 @@ def test_input_embedder():
387387
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
388388

389389
atom_mask = torch.ones((2, atom_seq_len)).bool()
390-
additional_molecule_feats = torch.randn(2, 16, 9)
390+
additional_molecule_feats = torch.randn(2, 16, 5)
391+
is_molecule_types = torch.randint(0, 2, (2, 16, 4)).bool()
391392
molecule_ids = torch.randint(0, 32, (2, 16))
392393

393394
embedder = InputFeatureEmbedder(
@@ -400,6 +401,7 @@ def test_input_embedder():
400401
atompair_inputs = atompair_inputs,
401402
molecule_atom_lens = molecule_atom_lens,
402403
molecule_ids = molecule_ids,
404+
is_molecule_types = is_molecule_types,
403405
additional_molecule_feats = additional_molecule_feats
404406
)
405407

@@ -433,7 +435,8 @@ def test_alphafold3(
433435
if window_atompair_inputs:
434436
atompair_inputs = full_pairwise_repr_to_windowed(atompair_inputs, window_size = atoms_per_window)
435437

436-
additional_molecule_feats = torch.randn(2, seq_len, 9)
438+
additional_molecule_feats = torch.randn(2, seq_len, 5)
439+
is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool()
437440
molecule_ids = torch.randint(0, 32, (2, seq_len))
438441

439442
atom_parent_ids = None
@@ -487,6 +490,7 @@ def test_alphafold3(
487490
molecule_atom_lens = molecule_atom_lens,
488491
atom_parent_ids = atom_parent_ids,
489492
atompair_inputs = atompair_inputs,
493+
is_molecule_types = is_molecule_types,
490494
additional_molecule_feats = additional_molecule_feats,
491495
token_bonds = token_bonds,
492496
msa = msa,
@@ -511,6 +515,7 @@ def test_alphafold3(
511515
molecule_ids = molecule_ids,
512516
molecule_atom_lens = molecule_atom_lens,
513517
atompair_inputs = atompair_inputs,
518+
is_molecule_types = is_molecule_types,
514519
additional_molecule_feats = additional_molecule_feats,
515520
msa = msa,
516521
templates = template_feats,
@@ -526,7 +531,8 @@ def test_alphafold3_without_msa_and_templates():
526531

527532
atom_inputs = torch.randn(2, atom_seq_len, 77)
528533
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
529-
additional_molecule_feats = torch.randn(2, seq_len, 9)
534+
additional_molecule_feats = torch.randn(2, seq_len, 5)
535+
is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool()
530536
molecule_ids = torch.randint(0, 32, (2, seq_len))
531537

532538
atom_pos = torch.randn(2, atom_seq_len, 3)
@@ -567,6 +573,7 @@ def test_alphafold3_without_msa_and_templates():
567573
molecule_ids = molecule_ids,
568574
molecule_atom_lens = molecule_atom_lens,
569575
atompair_inputs = atompair_inputs,
576+
is_molecule_types = is_molecule_types,
570577
additional_molecule_feats = additional_molecule_feats,
571578
atom_pos = atom_pos,
572579
molecule_atom_indices = molecule_atom_indices,
@@ -587,7 +594,8 @@ def test_alphafold3_force_return_loss():
587594

588595
atom_inputs = torch.randn(2, atom_seq_len, 77)
589596
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
590-
additional_molecule_feats = torch.randn(2, seq_len, 9)
597+
additional_molecule_feats = torch.randn(2, seq_len, 5)
598+
is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool()
591599
molecule_ids = torch.randint(0, 32, (2, seq_len))
592600

593601
atom_pos = torch.randn(2, atom_seq_len, 3)
@@ -628,6 +636,7 @@ def test_alphafold3_force_return_loss():
628636
molecule_ids = molecule_ids,
629637
molecule_atom_lens = molecule_atom_lens,
630638
atompair_inputs = atompair_inputs,
639+
is_molecule_types = is_molecule_types,
631640
additional_molecule_feats = additional_molecule_feats,
632641
atom_pos = atom_pos,
633642
molecule_atom_indices = molecule_atom_indices,
@@ -648,6 +657,7 @@ def test_alphafold3_force_return_loss():
648657
molecule_ids = molecule_ids,
649658
molecule_atom_lens = molecule_atom_lens,
650659
atompair_inputs = atompair_inputs,
660+
is_molecule_types = is_molecule_types,
651661
additional_molecule_feats = additional_molecule_feats,
652662
return_loss_breakdown = True,
653663
return_loss = True # force returning loss even if no labels given
@@ -676,7 +686,8 @@ def test_alphafold3_with_atom_and_bond_embeddings():
676686
atom_inputs = torch.randn(2, atom_seq_len, 77)
677687
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
678688

679-
additional_molecule_feats = torch.randn(2, seq_len, 9)
689+
additional_molecule_feats = torch.randn(2, seq_len, 5)
690+
is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool()
680691
molecule_ids = torch.randint(0, 32, (2, seq_len))
681692

682693
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
@@ -706,6 +717,7 @@ def test_alphafold3_with_atom_and_bond_embeddings():
706717
atompair_inputs = atompair_inputs,
707718
molecule_ids = molecule_ids,
708719
molecule_atom_lens = molecule_atom_lens,
720+
is_molecule_types = is_molecule_types,
709721
additional_molecule_feats = additional_molecule_feats,
710722
msa = msa,
711723
msa_mask = msa_mask,

tests/test_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def __getitem__(self, idx):
4747
atompair_inputs = torch.randn(atom_seq_len, atom_seq_len, 5)
4848

4949
molecule_atom_lens = torch.randint(1, self.atoms_per_window, (seq_len,))
50-
additional_molecule_feats = torch.randn(seq_len, 9)
50+
additional_molecule_feats = torch.randn(seq_len, 5)
51+
is_molecule_types = torch.randint(0, 2, (seq_len, 4)).bool()
5152
molecule_ids = torch.randint(0, 32, (seq_len,))
5253
token_bonds = torch.randint(0, 2, (seq_len, seq_len)).bool()
5354

@@ -78,6 +79,7 @@ def __getitem__(self, idx):
7879
token_bonds = token_bonds,
7980
molecule_atom_lens = molecule_atom_lens,
8081
additional_molecule_feats = additional_molecule_feats,
82+
is_molecule_types = is_molecule_types,
8183
templates = templates,
8284
template_mask = template_mask,
8385
msa = msa,

0 commit comments

Comments
 (0)