Skip to content

Commit 1e1bc21

Browse files
committed
Move towards using residue_atom_lens for both the packed and unpacked atom representations, for clarity. The atom mask is now derived internally for the atom transformer.
1 parent 9ab2eb0 commit 1e1bc21

File tree

4 files changed

+31
-30
lines changed

4 files changed

+31
-30
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ seq_len = 16
3636
atom_seq_len = seq_len * 27
3737

3838
atom_inputs = torch.randn(2, atom_seq_len, 77)
39-
atom_mask = torch.ones((2, atom_seq_len)).bool()
39+
atom_lens = torch.randint(0, 27, (2, seq_len))
4040
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
4141
additional_residue_feats = torch.randn(2, seq_len, 33)
4242

@@ -61,7 +61,7 @@ resolved_labels = torch.randint(0, 2, (2, seq_len))
6161
loss = alphafold3(
6262
num_recycling_steps = 2,
6363
atom_inputs = atom_inputs,
64-
atom_mask = atom_mask,
64+
residue_atom_lens = atom_lens,
6565
atompair_feats = atompair_feats,
6666
additional_residue_feats = additional_residue_feats,
6767
msa = msa,

alphafold3_pytorch/alphafold3.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555

5656
# constants
5757

58+
DIM_ADDITIONAL_RESIDUE_FEATS = 10
59+
5860
LinearNoBias = partial(Linear, bias = False)
5961

6062
# helper functions
@@ -1435,7 +1437,8 @@ def forward(
14351437
w = self.atoms_per_window
14361438
is_unpacked_repr = exists(w)
14371439

1438-
assert is_unpacked_repr ^ exists(residue_atom_lens), '`residue_atom_lens` must be passed in if using packed_atom_repr (atoms_per_window is None)'
1440+
if not is_unpacked_repr:
1441+
assert exists(residue_atom_lens), '`residue_atom_lens` must be passed in if using packed_atom_repr (atoms_per_window is None)'
14391442

14401443
atom_feats = self.proj(atom_feats)
14411444

@@ -1613,7 +1616,8 @@ def forward(
16131616
w = self.atoms_per_window
16141617
is_unpacked_repr = exists(w)
16151618

1616-
assert is_unpacked_repr ^ exists(residue_atom_lens)
1619+
if not is_unpacked_repr:
1620+
assert exists(residue_atom_lens)
16171621

16181622
# in the paper, it seems they pack the atom feats
16191623
# but in this impl, will just use windows for simplicity when communicating between atom and residue resolutions. bit less efficient
@@ -2350,7 +2354,6 @@ def __init__(
23502354
self,
23512355
*,
23522356
dim_atom_inputs,
2353-
dim_additional_residue_feats = 10,
23542357
atoms_per_window = 27,
23552358
dim_atom = 128,
23562359
dim_atompair = 16,
@@ -2396,9 +2399,7 @@ def __init__(
23962399
atoms_per_window = atoms_per_window
23972400
)
23982401

2399-
dim_single_input = dim_token + dim_additional_residue_feats
2400-
2401-
self.dim_additional_residue_feats = dim_additional_residue_feats
2402+
dim_single_input = dim_token + DIM_ADDITIONAL_RESIDUE_FEATS
24022403

24032404
self.single_input_to_single_init = LinearNoBias(dim_single_input, dim_single)
24042405
self.single_input_to_pairwise_init = LinearNoBiasThenOuterSum(dim_single_input, dim_pairwise)
@@ -2415,7 +2416,7 @@ def forward(
24152416

24162417
) -> EmbeddedInputs:
24172418

2418-
assert additional_residue_feats.shape[-1] == self.dim_additional_residue_feats
2419+
assert additional_residue_feats.shape[-1] == DIM_ADDITIONAL_RESIDUE_FEATS
24192420

24202421
w = self.atoms_per_window
24212422

@@ -2608,7 +2609,6 @@ def __init__(
26082609
self,
26092610
*,
26102611
dim_atom_inputs,
2611-
dim_additional_residue_feats,
26122612
dim_template_feats,
26132613
dim_template_model = 64,
26142614
atoms_per_window = 27,
@@ -2713,7 +2713,6 @@ def __init__(
27132713

27142714
self.input_embedder = InputFeatureEmbedder(
27152715
dim_atom_inputs = dim_atom_inputs,
2716-
dim_additional_residue_feats = dim_additional_residue_feats,
27172716
atoms_per_window = atoms_per_window,
27182717
dim_atom = dim_atom,
27192718
dim_atompair = dim_atompair,
@@ -2723,7 +2722,7 @@ def __init__(
27232722
**input_embedder_kwargs
27242723
)
27252724

2726-
dim_single_inputs = dim_input_embedder_token + dim_additional_residue_feats
2725+
dim_single_inputs = dim_input_embedder_token + DIM_ADDITIONAL_RESIDUE_FEATS
27272726

27282727
# relative positional encoding
27292728
# used by pairwise in main alphafold2 trunk
@@ -2866,22 +2865,28 @@ def forward(
28662865

28672866
atom_seq_len = atom_inputs.shape[-2]
28682867

2868+
assert exists(residue_atom_lens) or exists(atom_mask)
2869+
28692870
# determine whether using packed or unpacked atom rep
28702871

2871-
assert exists(residue_atom_lens) ^ exists(atom_mask), 'either atom_lens or atom_mask must be given depending on whether packed_atom_repr kwarg is True or False'
2872+
if self.packed_atom_repr:
2873+
assert exists(residue_atom_lens), 'residue_atom_lens must be given if using packed atom repr'
28722874

28732875
if exists(residue_atom_lens):
2874-
assert self.packed_atom_repr, '`packed_atom_repr` kwarg on Alphafold3 must be True when passing in `atom_lens`'
28752876

2876-
# handle atom mask
2877+
if self.packed_atom_repr:
2878+
# handle atom mask
28772879

2878-
total_atoms = residue_atom_lens.sum(dim = -1)
2879-
atom_mask = lens_to_mask(total_atoms, max_len = atom_seq_len)
2880+
total_atoms = residue_atom_lens.sum(dim = -1)
2881+
atom_mask = lens_to_mask(total_atoms, max_len = atom_seq_len)
28802882

2881-
# handle offsets for residue atom indices
2883+
# handle offsets for residue atom indices
28822884

2883-
if exists(residue_atom_indices):
2884-
residue_atom_indices += F.pad(residue_atom_lens, (-1, 1), value = 0)
2885+
if exists(residue_atom_indices):
2886+
residue_atom_indices += F.pad(residue_atom_lens, (-1, 1), value = 0)
2887+
else:
2888+
atom_mask = lens_to_mask(residue_atom_lens, max_len = self.atoms_per_window)
2889+
atom_mask = rearrange(atom_mask, 'b ... -> b (...)')
28852890

28862891
# get atom sequence length and residue sequence length depending on whether using packed atomic seq
28872892

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.34"
3+
version = "0.0.36"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,6 @@ def test_input_embedder():
344344

345345
embedder = InputFeatureEmbedder(
346346
dim_atom_inputs = 77,
347-
dim_additional_residue_feats = 10
348347
)
349348

350349
embedder(
@@ -369,7 +368,7 @@ def test_alphafold3():
369368
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
370369

371370
atom_inputs = torch.randn(2, atom_seq_len, 77)
372-
atom_mask = torch.ones((2, atom_seq_len)).bool()
371+
atom_lens = torch.randint(0, 27, (2, seq_len))
373372
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
374373
additional_residue_feats = torch.randn(2, seq_len, 10)
375374

@@ -389,7 +388,6 @@ def test_alphafold3():
389388

390389
alphafold3 = Alphafold3(
391390
dim_atom_inputs = 77,
392-
dim_additional_residue_feats = 10,
393391
dim_template_feats = 44,
394392
num_dist_bins = 38,
395393
confidence_head_kwargs = dict(
@@ -414,7 +412,7 @@ def test_alphafold3():
414412
loss, breakdown = alphafold3(
415413
num_recycling_steps = 2,
416414
atom_inputs = atom_inputs,
417-
atom_mask = atom_mask,
415+
residue_atom_lens = atom_lens,
418416
atompair_feats = atompair_feats,
419417
additional_residue_feats = additional_residue_feats,
420418
token_bond = token_bond,
@@ -437,7 +435,7 @@ def test_alphafold3():
437435
sampled_atom_pos = alphafold3(
438436
num_sample_steps = 16,
439437
atom_inputs = atom_inputs,
440-
atom_mask = atom_mask,
438+
residue_atom_lens = atom_lens,
441439
atompair_feats = atompair_feats,
442440
additional_residue_feats = additional_residue_feats,
443441
msa = msa,
@@ -452,7 +450,7 @@ def test_alphafold3_without_msa_and_templates():
452450
atom_seq_len = seq_len * 27
453451

454452
atom_inputs = torch.randn(2, atom_seq_len, 77)
455-
atom_mask = torch.ones((2, atom_seq_len)).bool()
453+
atom_lens = torch.randint(0, 27, (2, seq_len))
456454
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
457455
additional_residue_feats = torch.randn(2, seq_len, 10)
458456

@@ -467,7 +465,6 @@ def test_alphafold3_without_msa_and_templates():
467465

468466
alphafold3 = Alphafold3(
469467
dim_atom_inputs = 77,
470-
dim_additional_residue_feats = 10,
471468
dim_template_feats = 44,
472469
num_dist_bins = 38,
473470
confidence_head_kwargs = dict(
@@ -492,7 +489,7 @@ def test_alphafold3_without_msa_and_templates():
492489
loss, breakdown = alphafold3(
493490
num_recycling_steps = 2,
494491
atom_inputs = atom_inputs,
495-
atom_mask = atom_mask,
492+
residue_atom_lens = atom_lens,
496493
atompair_feats = atompair_feats,
497494
additional_residue_feats = additional_residue_feats,
498495
atom_pos = atom_pos,
@@ -536,7 +533,6 @@ def test_alphafold3_with_packed_atom_repr():
536533

537534
alphafold3 = Alphafold3(
538535
dim_atom_inputs = 77,
539-
dim_additional_residue_feats = 10,
540536
dim_template_feats = 44,
541537
num_dist_bins = 38,
542538
packed_atom_repr = True,

0 commit comments

Comments
 (0)