@@ -1982,12 +1982,14 @@ def __init__(
19821982 dim_input_embedder_token = 384 ,
19831983 dim_single = 384 ,
19841984 dim_pairwise = 128 ,
1985+ dim_token = 768 ,
19851986 atompair_dist_bins : Float [' dist_bins' ] = torch .linspace (3 , 20 , 37 ),
19861987 ignore_index = - 1 ,
19871988 num_dist_bins = 38 ,
19881989 num_plddt_bins = 50 ,
19891990 num_pde_bins = 64 ,
19901991 num_pae_bins = 64 ,
1992+ sigma_data = 16 ,
19911993 loss_confidence_weight = 1e-4 ,
19921994 loss_distogram_weight = 1e-2 ,
19931995 loss_diffusion_weight = 4. ,
@@ -2023,6 +2025,32 @@ def __init__(
20232025 relative_position_encoding_kwargs : dict = dict (
20242026 r_max = 32 ,
20252027 s_max = 2 ,
2028+ ),
2029+ diffusion_module_kwargs : dict = dict (
2030+ single_cond_kwargs = dict (
2031+ num_transitions = 2 ,
2032+ transition_expansion_factor = 2 ,
2033+ ),
2034+ pairwise_cond_kwargs = dict (
2035+ num_transitions = 2
2036+ ),
2037+ atom_encoder_depth = 3 ,
2038+ atom_encoder_heads = 4 ,
2039+ token_transformer_depth = 24 ,
2040+ token_transformer_heads = 16 ,
2041+ atom_decoder_depth = 3 ,
2042+ atom_decoder_heads = 4
2043+ ),
2044+ edm_kwargs : dict = dict (
2045+ sigma_min = 0.002 ,
2046+ sigma_max = 80 ,
2047+ rho = 7 ,
2048+ P_mean = - 1.2 ,
2049+ P_std = 1.2 ,
2050+ S_churn = 80 ,
2051+ S_tmin = 0.05 ,
2052+ S_tmax = 50 ,
2053+ S_noise = 1.003 ,
20262054 )
20272055 ):
20282056 super ().__init__ ()
@@ -2091,6 +2119,27 @@ def __init__(
20912119 LinearNoBias (dim_pairwise , dim_pairwise )
20922120 )
20932121
2122+ # diffusion
2123+
2124+ self .diffusion_module = DiffusionModule (
2125+ dim_pairwise_trunk = dim_pairwise ,
2126+ dim_pairwise_rel_pos_feats = dim_pairwise ,
2127+ atoms_per_window = atoms_per_window ,
2128+ dim_pairwise = dim_pairwise ,
2129+ sigma_data = sigma_data ,
2130+ dim_atom = dim_atom ,
2131+ dim_atompair = dim_atompair ,
2132+ dim_token = dim_token ,
2133+ dim_single = dim_single + dim_single_inputs ,
2134+ ** diffusion_module_kwargs
2135+ )
2136+
2137+ self .edm = ElucidatedAtomDiffusion (
2138+ self .diffusion_module ,
2139+ sigma_data = sigma_data ,
2140+ ** edm_kwargs
2141+ )
2142+
20942143 # logit heads
20952144
20962145 self .distogram_head = DistogramHead (
@@ -2116,11 +2165,11 @@ def __init__(
21162165 self .loss_confidence_weight = loss_confidence_weight
21172166 self .loss_diffusion_weight = loss_diffusion_weight
21182167
2119- self .register_buffer ('dummy ' , torch .tensor (0 ), persistent = False )
2168+ self .register_buffer ('zero ' , torch .tensor (0. ), persistent = False )
21202169
21212170 @property
21222171 def device (self ):
2123- return self .dummy .device
2172+ return self .zero .device
21242173
21252174 @typecheck
21262175 def forward (
@@ -2134,6 +2183,8 @@ def forward(
21342183 templates : Float ['b t n n dt' ],
21352184 template_mask : Bool ['b t' ],
21362185 num_recycling_steps : int = 1 ,
2186+ num_sample_steps : int | None = None ,
2187+ atom_pos : Float ['b m 3' ] | None = None ,
21372188 distance_labels : Int ['b n n' ] | None = None ,
21382189 pae_labels : Int ['b n n' ] | None = None ,
21392190 pde_labels : Int ['b n n' ] | None = None ,
@@ -2228,23 +2279,52 @@ def forward(
22282279 # determine whether to return loss if any labels were to be passed in
22292280 # otherwise will sample the atomic coordinates
22302281
2282+ atom_pos_given = exists (atom_pos )
2283+
22312284 labels = (distance_labels , pae_labels , pde_labels , plddt_labels , resolved_labels )
2232- return_loss = any ([* map (exists , labels )])
2285+ has_labels = any ([* map (exists , labels )])
2286+
2287+ return_loss = atom_pos_given or has_labels
2288+
2289+ # setup all the data necessary for conditioning the diffusion module
2290+
2291+ diffusion_cond = dict (
2292+ atom_feats = atom_feats ,
2293+ atompair_feats = atompair_feats ,
2294+ atom_mask = atom_mask ,
2295+ mask = mask ,
2296+ single_trunk_repr = single ,
2297+ single_inputs_repr = single_inputs ,
2298+ pairwise_trunk = pairwise ,
2299+ pairwise_rel_pos_feats = relative_position_encoding
2300+ )
2301+
2302+ # if neither atom positions or any labels are passed in, sample a structure and return
22332303
22342304 if not return_loss :
2235- return torch .randn ((* atom_inputs .shape [:2 ], 3 ), device = self .device )
2305+ return self .edm .sample (num_sample_steps = num_sample_steps , ** diffusion_cond )
2306+
2307+ # otherwise, noise and make it learn to denoise
2308+
2309+ diffusion_loss = self .zero
2310+
2311+ if exists (atom_pos ):
2312+ diffusion_loss = self .edm (atom_pos , ** diffusion_cond )
22362313
22372314 # calculate all logits and losses
22382315
22392316 ignore = self .ignore_index
22402317
2318+ distogram_loss = self .zero
2319+
22412320 if exists (distance_labels ):
22422321 distance_labels = torch .where (pairwise_mask , distance_labels , ignore )
22432322 distogram_logits = self .distogram_head (pairwise )
22442323 distogram_loss = F .cross_entropy (distogram_logits , distance_labels , ignore_index = ignore )
22452324
22462325 loss = (
2247- distogram_loss * self .loss_distogram_weight
2326+ distogram_loss * self .loss_distogram_weight +
2327+ diffusion_loss * self .loss_diffusion_weight
22482328 )
22492329
22502330 return loss
0 commit comments