global ein notation:

b - batch
ba - batch with augmentation
h - heads
n - residue sequence length
i - residue sequence length (source)
1920from __future__ import annotations
2021
2122from math import pi , sqrt
22- from functools import partial
23+ from functools import partial , wraps
2324from collections import namedtuple
2425
2526import torch
@@ -78,6 +79,14 @@ def pack_one(t, pattern):
def unpack_one(t, ps, pattern):
    """Unpack `t` using packed shapes `ps` and `pattern`, returning only the first result.

    Thin convenience wrapper around `unpack` for the case where a single
    tensor was packed (presumably by `pack_one` - confirm against caller).
    """
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
8081
def maybe(fn):
    """Decorate `fn` so it is skipped when its first argument does not exist.

    The returned wrapper checks `exists(t)` on the first positional argument:
    if that check fails it returns None without calling `fn`; otherwise it
    forwards `t` and all remaining positional / keyword arguments to `fn`
    and returns its result.
    """
    @wraps(fn)
    def inner(t, *args, **kwargs):
        # only invoke the wrapped function when the leading argument exists
        return fn(t, *args, **kwargs) if exists(t) else None

    return inner
89+
8190# Loss functions
8291
8392@typecheck
@@ -1562,9 +1571,9 @@ class DiffusionLossBreakdown(NamedTuple):
15621571
class ElucidatedAtomDiffusionReturn(NamedTuple):
    # Return bundle of the elucidated atom diffusion forward pass.
    # 'ba' in the shape strings is the batch-with-augmentation axis
    # (see the ein notation legend at the top of the file).

    # scalar loss (empty shape string)
    loss: Float['']

    # denoised atom coordinates; presumably 'm' is the atom axis - confirm
    denoised_atom_pos: Float['ba m 3']

    # per-term breakdown of the loss
    loss_breakdown: DiffusionLossBreakdown

    # one noise sigma per (augmented) batch element
    noise_sigmas: Float[' ba ']
15681577
15691578class ElucidatedAtomDiffusion (Module ):
15701579 @typecheck
@@ -1583,7 +1592,7 @@ def __init__(
15831592 S_tmin = 0.05 ,
15841593 S_tmax = 50 ,
15851594 S_noise = 1.003 ,
1586- smooth_lddt_loss_kwargs : dict = dict ()
1595+ smooth_lddt_loss_kwargs : dict = dict (),
15871596 ):
15881597 super ().__init__ ()
15891598 self .net = net
@@ -1749,16 +1758,24 @@ def forward(
17491758 self ,
17501759 normalized_atom_pos : Float ['b m 3' ],
17511760 atom_mask : Bool ['b m' ],
1761+ atom_feats : Float ['b m da' ],
1762+ atompair_feats : Float ['b m m dap' ],
1763+ mask : Bool ['b n' ],
1764+ single_trunk_repr : Float ['b n dst' ],
1765+ single_inputs_repr : Float ['b n dsi' ],
1766+ pairwise_trunk : Float ['b n n dpt' ],
1767+ pairwise_rel_pos_feats : Float ['b n n dpr' ],
17521768 return_denoised_pos = False ,
17531769 additional_residue_feats : Float ['b n 10' ] | None = None ,
17541770 add_smooth_lddt_loss = False ,
17551771 add_bond_loss = False ,
17561772 nucleotide_loss_weight = 5. ,
17571773 ligand_loss_weight = 10. ,
17581774 return_loss_breakdown = False ,
1759- ** network_condition_kwargs
17601775 ) -> ElucidatedAtomDiffusionReturn :
17611776
1777+ # diffusion loss
1778+
17621779 batch_size = normalized_atom_pos .shape [0 ]
17631780
17641781 sigmas = self .noise_distribution (batch_size )
@@ -1768,12 +1785,19 @@ def forward(
17681785
17691786 noised_atom_pos = normalized_atom_pos + padded_sigmas * noise # alphas are 1. in the paper
17701787
1771- network_condition_kwargs .update (atom_mask = atom_mask )
1772-
17731788 denoised_atom_pos = self .preconditioned_network_forward (
17741789 noised_atom_pos ,
17751790 sigmas ,
1776- network_condition_kwargs = network_condition_kwargs
1791+ network_condition_kwargs = dict (
1792+ atom_feats = atom_feats ,
1793+ atom_mask = atom_mask ,
1794+ atompair_feats = atompair_feats ,
1795+ mask = mask ,
1796+ single_trunk_repr = single_trunk_repr ,
1797+ single_inputs_repr = single_inputs_repr ,
1798+ pairwise_trunk = pairwise_trunk ,
1799+ pairwise_rel_pos_feats = pairwise_rel_pos_feats ,
1800+ )
17771801 )
17781802
17791803 total_loss = 0.
@@ -2386,6 +2410,7 @@ def __init__(
23862410 num_pde_bins = 64 ,
23872411 num_pae_bins = 64 ,
23882412 sigma_data = 16 ,
2413+ diffusion_num_augmentations = 4 ,
23892414 loss_confidence_weight = 1e-4 ,
23902415 loss_distogram_weight = 1e-2 ,
23912416 loss_diffusion_weight = 4. ,
@@ -2447,12 +2472,18 @@ def __init__(
24472472 S_tmin = 0.05 ,
24482473 S_tmax = 50 ,
24492474 S_noise = 1.003 ,
2450- )
2475+ ),
2476+ augment_kwargs : dict = dict ()
24512477 ):
24522478 super ().__init__ ()
24532479
24542480 self .atoms_per_window = atoms_per_window
24552481
2482+ # augmentation
2483+
2484+ self .num_augmentations = diffusion_num_augmentations
2485+ self .augmenter = CentreRandomAugmentation (** augment_kwargs )
2486+
24562487 # input feature embedder
24572488
24582489 self .input_embedder = InputFeatureEmbedder (
@@ -2688,44 +2719,26 @@ def forward(
26882719
26892720 return_loss = atom_pos_given or has_labels
26902721
2691- # setup all the data necessary for conditioning the diffusion module
2692-
2693- diffusion_cond = dict (
2694- atom_feats = atom_feats ,
2695- atompair_feats = atompair_feats ,
2696- atom_mask = atom_mask ,
2697- mask = mask ,
2698- single_trunk_repr = single ,
2699- single_inputs_repr = single_inputs ,
2700- pairwise_trunk = pairwise ,
2701- pairwise_rel_pos_feats = relative_position_encoding
2702- )
2703-
27042722 # if neither atom positions or any labels are passed in, sample a structure and return
27052723
27062724 if not return_loss :
27072725 return self .edm .sample (
27082726 num_sample_steps = num_sample_steps ,
2709- ** diffusion_cond
2727+ atom_feats = atom_feats ,
2728+ atompair_feats = atompair_feats ,
2729+ atom_mask = atom_mask ,
2730+ mask = mask ,
2731+ single_trunk_repr = single ,
2732+ single_inputs_repr = single_inputs ,
2733+ pairwise_trunk = pairwise ,
2734+ pairwise_rel_pos_feats = relative_position_encoding
27102735 )
27112736
27122737 # losses default to 0
27132738
27142739 distogram_loss = diffusion_loss = confidence_loss = pae_loss = pde_loss = plddt_loss = resolved_loss = self .zero
27152740
2716- # otherwise, noise and make it learn to denoise
2717-
2718- if exists (atom_pos ):
2719- diffusion_loss , denoised_atom_pos , diffusion_loss_breakdown , _ = self .edm (
2720- atom_pos ,
2721- additional_residue_feats = additional_residue_feats ,
2722- add_smooth_lddt_loss = diffusion_add_smooth_lddt_loss ,
2723- add_bond_loss = diffusion_add_bond_loss ,
2724- return_denoised_pos = True ,
2725- ** diffusion_cond
2726- )
2727-
2728- # calculate all logits and losses
2741+ # calculate distogram logits and losses
27292742
27302743 ignore = self .ignore_index
27312744
@@ -2736,15 +2749,82 @@ def forward(
27362749 distogram_logits = self .distogram_head (pairwise )
27372750 distogram_loss = F .cross_entropy (distogram_logits , distance_labels , ignore_index = ignore )
27382751
2752+ # otherwise, noise and make it learn to denoise
2753+
2754+ calc_diffusion_loss = exists (atom_pos )
2755+
2756+ if calc_diffusion_loss :
2757+
2758+ num_augs = self .num_augmentations
2759+
2760+ # take care of augmentation
2761+ # they did 48 during training, as the trunk did the heavy lifting
2762+
2763+ if num_augs > 1 :
2764+ (
2765+ atom_pos ,
2766+ atom_mask ,
2767+ atom_feats ,
2768+ atompair_feats ,
2769+ mask ,
2770+ pairwise_mask ,
2771+ single ,
2772+ single_inputs ,
2773+ pairwise ,
2774+ relative_position_encoding ,
2775+ additional_residue_feats ,
2776+ residue_atom_indices ,
2777+ pae_labels ,
2778+ pde_labels ,
2779+ plddt_labels ,
2780+ resolved_labels ,
2781+
2782+ ) = tuple (
2783+ maybe (repeat )(t , 'b ... -> (b a) ...' , a = num_augs )
2784+ for t in (
2785+ atom_pos ,
2786+ atom_mask ,
2787+ atom_feats ,
2788+ atompair_feats ,
2789+ mask ,
2790+ pairwise_mask ,
2791+ single ,
2792+ single_inputs ,
2793+ pairwise ,
2794+ relative_position_encoding ,
2795+ additional_residue_feats ,
2796+ residue_atom_indices ,
2797+ pae_labels ,
2798+ pde_labels ,
2799+ plddt_labels ,
2800+ resolved_labels
2801+ )
2802+ )
2803+
2804+ atom_pos = self .augmenter (atom_pos )
2805+
2806+ diffusion_loss , denoised_atom_pos , diffusion_loss_breakdown , _ = self .edm (
2807+ atom_pos ,
2808+ additional_residue_feats = additional_residue_feats ,
2809+ add_smooth_lddt_loss = diffusion_add_smooth_lddt_loss ,
2810+ add_bond_loss = diffusion_add_bond_loss ,
2811+ atom_feats = atom_feats ,
2812+ atompair_feats = atompair_feats ,
2813+ atom_mask = atom_mask ,
2814+ mask = mask ,
2815+ single_trunk_repr = single ,
2816+ single_inputs_repr = single_inputs ,
2817+ pairwise_trunk = pairwise ,
2818+ pairwise_rel_pos_feats = relative_position_encoding ,
2819+ return_denoised_pos = True ,
2820+ )
2821+
27392822 # confidence head
27402823
27412824 should_call_confidence_head = any ([* map (exists , confidence_head_labels )])
27422825 return_pae_logits = exists (pae_labels )
27432826
2744- if should_call_confidence_head :
2745- assert exists (atom_pos ), 'diffusion module needs to have been called'
2746-
2747- assert exists (residue_atom_indices )
2827+ if calc_diffusion_loss and should_call_confidence_head :
27482828
27492829 pred_atom_pos = einx .get_at ('b (n [w]) c, b n -> b n c' , denoised_atom_pos , residue_atom_indices )
27502830
0 commit comments