Skip to content

Commit ff1daf1

Browse files
committed
minimal viable working implementation of atomic resolution confidence heads
1 parent 7e65978 commit ff1daf1

File tree

3 files changed

+38
-17
lines changed

3 files changed

+38
-17
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3137,7 +3137,8 @@ def __init__(
31373137
S_noise = 1.003,
31383138
),
31393139
augment_kwargs: dict = dict(),
3140-
stochastic_frame_average = False
3140+
stochastic_frame_average = False,
3141+
confidence_head_atom_resolution = False
31413142
):
31423143
super().__init__()
31433144

@@ -3315,6 +3316,7 @@ def __init__(
33153316
num_plddt_bins = num_plddt_bins,
33163317
num_pde_bins = num_pde_bins,
33173318
num_pae_bins = num_pae_bins,
3319+
atom_resolution = confidence_head_atom_resolution,
33183320
**confidence_head_kwargs
33193321
)
33203322

@@ -3434,11 +3436,11 @@ def forward(
34343436
molecule_atom_indices: Int['b n'] | None = None, # the 'token centre atoms' mentioned in the paper, unsure where it is used in the architecture
34353437
num_sample_steps: int | None = None,
34363438
atom_pos: Float['b m 3'] | None = None,
3437-
distance_labels: Int['b n n'] | None = None,
3438-
pae_labels: Int['b n n'] | None = None,
3439-
pde_labels: Int['b n n'] | None = None,
3440-
plddt_labels: Int['b n'] | None = None,
3441-
resolved_labels: Int['b n'] | None = None,
3439+
distance_labels: Int['b n n'] | Int['b m m'] | None = None,
3440+
pae_labels: Int['b n n'] | Int['b m m'] | None = None,
3441+
pde_labels: Int['b n n'] | Int['b m m'] | None = None,
3442+
plddt_labels: Int['b n'] | Int['b m'] | None = None,
3443+
resolved_labels: Int['b n'] | Int['b m'] | None = None,
34423444
return_loss_breakdown = False,
34433445
return_loss: bool = None,
34443446
return_present_sampled_atoms: bool = False,
@@ -3710,6 +3712,8 @@ def forward(
37103712
pairwise_repr = pairwise.detach(),
37113713
pred_atom_pos = confidence_head_atom_pos_input.detach(),
37123714
molecule_atom_indices = molecule_atom_indices,
3715+
molecule_atom_lens = molecule_atom_lens,
3716+
atom_feats = atom_feats,
37133717
mask = mask,
37143718
return_pae_logits = True
37153719
)
@@ -3884,24 +3888,37 @@ def forward(
38843888
pairwise_repr = pairwise.detach(),
38853889
pred_atom_pos = denoised_atom_pos.detach(),
38863890
molecule_atom_indices = molecule_atom_indices,
3891+
molecule_atom_lens = molecule_atom_lens,
38873892
mask = mask,
3893+
atom_feats = atom_feats,
38883894
return_pae_logits = return_pae_logits
38893895
)
38903896

3897+
# determine which mask to use for labels depending on atom resolution or not for confidence head
3898+
3899+
label_mask = mask
3900+
3901+
if self.confidence_head.atom_resolution:
3902+
label_mask = atom_mask
3903+
3904+
label_pairwise_mask = einx.logical_and('... i, ... j -> ... i j', label_mask, label_mask)
3905+
3906+
# cross entropy losses
3907+
38913908
if exists(pae_labels):
3892-
pae_labels = torch.where(pairwise_mask, pae_labels, ignore)
3909+
pae_labels = torch.where(label_pairwise_mask, pae_labels, ignore)
38933910
pae_loss = F.cross_entropy(ch_logits.pae, pae_labels, ignore_index = ignore)
38943911

38953912
if exists(pde_labels):
3896-
pde_labels = torch.where(pairwise_mask, pde_labels, ignore)
3913+
pde_labels = torch.where(label_pairwise_mask, pde_labels, ignore)
38973914
pde_loss = F.cross_entropy(ch_logits.pde, pde_labels, ignore_index = ignore)
38983915

38993916
if exists(plddt_labels):
3900-
plddt_labels = torch.where(mask, plddt_labels, ignore)
3917+
plddt_labels = torch.where(label_mask, plddt_labels, ignore)
39013918
plddt_loss = F.cross_entropy(ch_logits.plddt, plddt_labels, ignore_index = ignore)
39023919

39033920
if exists(resolved_labels):
3904-
resolved_labels = torch.where(mask, resolved_labels, ignore)
3921+
resolved_labels = torch.where(label_mask, resolved_labels, ignore)
39053922
resolved_loss = F.cross_entropy(ch_logits.resolved, resolved_labels, ignore_index = ignore)
39063923

39073924
confidence_loss = pae_loss + pde_loss + plddt_loss + resolved_loss

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.11"
3+
version = "0.2.12"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -459,12 +459,14 @@ def test_distogram_head():
459459
@pytest.mark.parametrize('missing_atoms', (True, False))
460460
@pytest.mark.parametrize('atom_transformer_intramolecular_attn', (True, False))
461461
@pytest.mark.parametrize('num_molecule_mods', (0, 5))
462+
@pytest.mark.parametrize('confidence_head_atom_resolution', (True, False))
462463
def test_alphafold3(
463464
window_atompair_inputs: bool,
464465
stochastic_frame_average: bool,
465466
missing_atoms: bool,
466467
atom_transformer_intramolecular_attn: bool,
467-
num_molecule_mods: int
468+
num_molecule_mods: int,
469+
confidence_head_atom_resolution: bool
468470
):
469471
seq_len = 16
470472
atoms_per_window = 27
@@ -509,10 +511,11 @@ def test_alphafold3(
509511
distogram_atom_indices = molecule_atom_lens - 1
510512
molecule_atom_indices = molecule_atom_lens - 1
511513

512-
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
513-
pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
514-
plddt_labels = torch.randint(0, 50, (2, seq_len))
515-
resolved_labels = torch.randint(0, 2, (2, seq_len))
514+
label_len = atom_seq_len if confidence_head_atom_resolution else seq_len
515+
pae_labels = torch.randint(0, 64, (2, label_len, label_len))
516+
pde_labels = torch.randint(0, 64, (2, label_len, label_len))
517+
plddt_labels = torch.randint(0, 50, (2, label_len))
518+
resolved_labels = torch.randint(0, 2, (2, label_len))
516519

517520
alphafold3 = Alphafold3(
518521
dim_atom_inputs = 77,
@@ -538,7 +541,8 @@ def test_alphafold3(
538541
token_transformer_depth = 1,
539542
atom_decoder_depth = 1,
540543
),
541-
stochastic_frame_average = stochastic_frame_average
544+
stochastic_frame_average = stochastic_frame_average,
545+
confidence_head_atom_resolution = confidence_head_atom_resolution
542546
)
543547

544548
loss, breakdown = alphafold3(

0 commit comments

Comments
 (0)