Skip to content

Commit 1083f23

Browse files
committed
Augmentation makes the most sense at the outermost level, given the labels. They did up to 48 augmentations, applied after the trunk and its loss are processed.
1 parent 6bd8813 commit 1083f23

File tree

2 files changed

+121
-41
lines changed

2 files changed

+121
-41
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 120 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
global ein notation:
33
44
b - batch
5+
ba - batch with augmentation
56
h - heads
67
n - residue sequence length
78
i - residue sequence length (source)
@@ -19,7 +20,7 @@
1920
from __future__ import annotations
2021

2122
from math import pi, sqrt
22-
from functools import partial
23+
from functools import partial, wraps
2324
from collections import namedtuple
2425

2526
import torch
@@ -78,6 +79,14 @@ def pack_one(t, pattern):
7879
def unpack_one(t, ps, pattern):
7980
return unpack(t, ps, pattern)[0]
8081

82+
def maybe(fn):
    """Wrap `fn` so it is skipped when its first argument does not exist.

    The returned callable takes the same arguments as `fn`. If the first
    positional argument fails the `exists` check, `fn` is not called and
    None is returned; otherwise the result of `fn` is returned unchanged.
    """
    @wraps(fn)
    def inner(t, *args, **kwargs):
        # short-circuit on a missing first argument, otherwise delegate
        return fn(t, *args, **kwargs) if exists(t) else None

    return inner
89+
8190
# Loss functions
8291

8392
@typecheck
@@ -1562,9 +1571,9 @@ class DiffusionLossBreakdown(NamedTuple):
15621571

15631572
class ElucidatedAtomDiffusionReturn(NamedTuple):
15641573
loss: Float['']
1565-
denoised_atom_pos: Float['b m 3']
1574+
denoised_atom_pos: Float['ba m 3']
15661575
loss_breakdown: DiffusionLossBreakdown
1567-
noise_sigmas: Float[' b']
1576+
noise_sigmas: Float[' ba']
15681577

15691578
class ElucidatedAtomDiffusion(Module):
15701579
@typecheck
@@ -1583,7 +1592,7 @@ def __init__(
15831592
S_tmin = 0.05,
15841593
S_tmax = 50,
15851594
S_noise = 1.003,
1586-
smooth_lddt_loss_kwargs: dict = dict()
1595+
smooth_lddt_loss_kwargs: dict = dict(),
15871596
):
15881597
super().__init__()
15891598
self.net = net
@@ -1749,16 +1758,24 @@ def forward(
17491758
self,
17501759
normalized_atom_pos: Float['b m 3'],
17511760
atom_mask: Bool['b m'],
1761+
atom_feats: Float['b m da'],
1762+
atompair_feats: Float['b m m dap'],
1763+
mask: Bool['b n'],
1764+
single_trunk_repr: Float['b n dst'],
1765+
single_inputs_repr: Float['b n dsi'],
1766+
pairwise_trunk: Float['b n n dpt'],
1767+
pairwise_rel_pos_feats: Float['b n n dpr'],
17521768
return_denoised_pos = False,
17531769
additional_residue_feats: Float['b n 10'] | None = None,
17541770
add_smooth_lddt_loss = False,
17551771
add_bond_loss = False,
17561772
nucleotide_loss_weight = 5.,
17571773
ligand_loss_weight = 10.,
17581774
return_loss_breakdown = False,
1759-
**network_condition_kwargs
17601775
) -> ElucidatedAtomDiffusionReturn:
17611776

1777+
# diffusion loss
1778+
17621779
batch_size = normalized_atom_pos.shape[0]
17631780

17641781
sigmas = self.noise_distribution(batch_size)
@@ -1768,12 +1785,19 @@ def forward(
17681785

17691786
noised_atom_pos = normalized_atom_pos + padded_sigmas * noise # alphas are 1. in the paper
17701787

1771-
network_condition_kwargs.update(atom_mask = atom_mask)
1772-
17731788
denoised_atom_pos = self.preconditioned_network_forward(
17741789
noised_atom_pos,
17751790
sigmas,
1776-
network_condition_kwargs = network_condition_kwargs
1791+
network_condition_kwargs = dict(
1792+
atom_feats = atom_feats,
1793+
atom_mask = atom_mask,
1794+
atompair_feats = atompair_feats,
1795+
mask = mask,
1796+
single_trunk_repr = single_trunk_repr,
1797+
single_inputs_repr = single_inputs_repr,
1798+
pairwise_trunk = pairwise_trunk,
1799+
pairwise_rel_pos_feats = pairwise_rel_pos_feats,
1800+
)
17771801
)
17781802

17791803
total_loss = 0.
@@ -2386,6 +2410,7 @@ def __init__(
23862410
num_pde_bins = 64,
23872411
num_pae_bins = 64,
23882412
sigma_data = 16,
2413+
diffusion_num_augmentations = 4,
23892414
loss_confidence_weight = 1e-4,
23902415
loss_distogram_weight = 1e-2,
23912416
loss_diffusion_weight = 4.,
@@ -2447,12 +2472,18 @@ def __init__(
24472472
S_tmin = 0.05,
24482473
S_tmax = 50,
24492474
S_noise = 1.003,
2450-
)
2475+
),
2476+
augment_kwargs: dict = dict()
24512477
):
24522478
super().__init__()
24532479

24542480
self.atoms_per_window = atoms_per_window
24552481

2482+
# augmentation
2483+
2484+
self.num_augmentations = diffusion_num_augmentations
2485+
self.augmenter = CentreRandomAugmentation(**augment_kwargs)
2486+
24562487
# input feature embedder
24572488

24582489
self.input_embedder = InputFeatureEmbedder(
@@ -2688,44 +2719,26 @@ def forward(
26882719

26892720
return_loss = atom_pos_given or has_labels
26902721

2691-
# setup all the data necessary for conditioning the diffusion module
2692-
2693-
diffusion_cond = dict(
2694-
atom_feats = atom_feats,
2695-
atompair_feats = atompair_feats,
2696-
atom_mask = atom_mask,
2697-
mask = mask,
2698-
single_trunk_repr = single,
2699-
single_inputs_repr = single_inputs,
2700-
pairwise_trunk = pairwise,
2701-
pairwise_rel_pos_feats = relative_position_encoding
2702-
)
2703-
27042722
# if neither atom positions or any labels are passed in, sample a structure and return
27052723

27062724
if not return_loss:
27072725
return self.edm.sample(
27082726
num_sample_steps = num_sample_steps,
2709-
**diffusion_cond
2727+
atom_feats = atom_feats,
2728+
atompair_feats = atompair_feats,
2729+
atom_mask = atom_mask,
2730+
mask = mask,
2731+
single_trunk_repr = single,
2732+
single_inputs_repr = single_inputs,
2733+
pairwise_trunk = pairwise,
2734+
pairwise_rel_pos_feats = relative_position_encoding
27102735
)
27112736

27122737
# losses default to 0
27132738

27142739
distogram_loss = diffusion_loss = confidence_loss = pae_loss = pde_loss = plddt_loss = resolved_loss = self.zero
27152740

2716-
# otherwise, noise and make it learn to denoise
2717-
2718-
if exists(atom_pos):
2719-
diffusion_loss, denoised_atom_pos, diffusion_loss_breakdown, _ = self.edm(
2720-
atom_pos,
2721-
additional_residue_feats = additional_residue_feats,
2722-
add_smooth_lddt_loss = diffusion_add_smooth_lddt_loss,
2723-
add_bond_loss = diffusion_add_bond_loss,
2724-
return_denoised_pos = True,
2725-
**diffusion_cond
2726-
)
2727-
2728-
# calculate all logits and losses
2741+
# calculate distogram logits and losses
27292742

27302743
ignore = self.ignore_index
27312744

@@ -2736,15 +2749,82 @@ def forward(
27362749
distogram_logits = self.distogram_head(pairwise)
27372750
distogram_loss = F.cross_entropy(distogram_logits, distance_labels, ignore_index = ignore)
27382751

2752+
# otherwise, noise and make it learn to denoise
2753+
2754+
calc_diffusion_loss = exists(atom_pos)
2755+
2756+
if calc_diffusion_loss:
2757+
2758+
num_augs = self.num_augmentations
2759+
2760+
# take care of augmentation
2761+
# they did 48 during training, as the trunk did the heavy lifting
2762+
2763+
if num_augs > 1:
2764+
(
2765+
atom_pos,
2766+
atom_mask,
2767+
atom_feats,
2768+
atompair_feats,
2769+
mask,
2770+
pairwise_mask,
2771+
single,
2772+
single_inputs,
2773+
pairwise,
2774+
relative_position_encoding,
2775+
additional_residue_feats,
2776+
residue_atom_indices,
2777+
pae_labels,
2778+
pde_labels,
2779+
plddt_labels,
2780+
resolved_labels,
2781+
2782+
) = tuple(
2783+
maybe(repeat)(t, 'b ... -> (b a) ...', a = num_augs)
2784+
for t in (
2785+
atom_pos,
2786+
atom_mask,
2787+
atom_feats,
2788+
atompair_feats,
2789+
mask,
2790+
pairwise_mask,
2791+
single,
2792+
single_inputs,
2793+
pairwise,
2794+
relative_position_encoding,
2795+
additional_residue_feats,
2796+
residue_atom_indices,
2797+
pae_labels,
2798+
pde_labels,
2799+
plddt_labels,
2800+
resolved_labels
2801+
)
2802+
)
2803+
2804+
atom_pos = self.augmenter(atom_pos)
2805+
2806+
diffusion_loss, denoised_atom_pos, diffusion_loss_breakdown, _ = self.edm(
2807+
atom_pos,
2808+
additional_residue_feats = additional_residue_feats,
2809+
add_smooth_lddt_loss = diffusion_add_smooth_lddt_loss,
2810+
add_bond_loss = diffusion_add_bond_loss,
2811+
atom_feats = atom_feats,
2812+
atompair_feats = atompair_feats,
2813+
atom_mask = atom_mask,
2814+
mask = mask,
2815+
single_trunk_repr = single,
2816+
single_inputs_repr = single_inputs,
2817+
pairwise_trunk = pairwise,
2818+
pairwise_rel_pos_feats = relative_position_encoding,
2819+
return_denoised_pos = True,
2820+
)
2821+
27392822
# confidence head
27402823

27412824
should_call_confidence_head = any([*map(exists, confidence_head_labels)])
27422825
return_pae_logits = exists(pae_labels)
27432826

2744-
if should_call_confidence_head:
2745-
assert exists(atom_pos), 'diffusion module needs to have been called'
2746-
2747-
assert exists(residue_atom_indices)
2827+
if calc_diffusion_loss and should_call_confidence_head:
27482828

27492829
pred_atom_pos = einx.get_at('b (n [w]) c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
27502830

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.16"
3+
version = "0.0.17"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)