@@ -3137,7 +3137,8 @@ def __init__(
31373137 S_noise = 1.003 ,
31383138 ),
31393139 augment_kwargs : dict = dict (),
3140- stochastic_frame_average = False
3140+ stochastic_frame_average = False ,
3141+ confidence_head_atom_resolution = False
31413142 ):
31423143 super ().__init__ ()
31433144
@@ -3315,6 +3316,7 @@ def __init__(
33153316 num_plddt_bins = num_plddt_bins ,
33163317 num_pde_bins = num_pde_bins ,
33173318 num_pae_bins = num_pae_bins ,
3319+ atom_resolution = confidence_head_atom_resolution ,
33183320 ** confidence_head_kwargs
33193321 )
33203322
@@ -3434,11 +3436,11 @@ def forward(
34343436 molecule_atom_indices : Int ['b n' ] | None = None , # the 'token centre atoms' mentioned in the paper, unsure where it is used in the architecture
34353437 num_sample_steps : int | None = None ,
34363438 atom_pos : Float ['b m 3' ] | None = None ,
3437- distance_labels : Int ['b n n' ] | None = None ,
3438- pae_labels : Int ['b n n' ] | None = None ,
3439- pde_labels : Int ['b n n' ] | None = None ,
3440- plddt_labels : Int ['b n' ] | None = None ,
3441- resolved_labels : Int ['b n' ] | None = None ,
3439+ distance_labels : Int ['b n n' ] | Int [ 'b m m' ] | None = None ,
3440+ pae_labels : Int ['b n n' ] | Int [ 'b m m' ] | None = None ,
3441+ pde_labels : Int ['b n n' ] | Int [ 'b m m' ] | None = None ,
3442+ plddt_labels : Int ['b n' ] | Int [ 'b m' ] | None = None ,
3443+ resolved_labels : Int ['b n' ] | Int [ 'b m' ] | None = None ,
34423444 return_loss_breakdown = False ,
34433445 return_loss : bool = None ,
34443446 return_present_sampled_atoms : bool = False ,
@@ -3710,6 +3712,8 @@ def forward(
37103712 pairwise_repr = pairwise .detach (),
37113713 pred_atom_pos = confidence_head_atom_pos_input .detach (),
37123714 molecule_atom_indices = molecule_atom_indices ,
3715+ molecule_atom_lens = molecule_atom_lens ,
3716+ atom_feats = atom_feats ,
37133717 mask = mask ,
37143718 return_pae_logits = True
37153719 )
@@ -3884,24 +3888,37 @@ def forward(
38843888 pairwise_repr = pairwise .detach (),
38853889 pred_atom_pos = denoised_atom_pos .detach (),
38863890 molecule_atom_indices = molecule_atom_indices ,
3891+ molecule_atom_lens = molecule_atom_lens ,
38873892 mask = mask ,
3893+ atom_feats = atom_feats ,
38883894 return_pae_logits = return_pae_logits
38893895 )
38903896
3897+ # determine which mask to use for labels depending on atom resolution or not for confidence head
3898+
3899+ label_mask = mask
3900+
3901+ if self .confidence_head .atom_resolution :
3902+ label_mask = atom_mask
3903+
3904+ label_pairwise_mask = einx .logical_and ('... i, ... j -> ... i j' , label_mask , label_mask )
3905+
3906+ # cross entropy losses
3907+
38913908 if exists (pae_labels ):
3892- pae_labels = torch .where (pairwise_mask , pae_labels , ignore )
3909+ pae_labels = torch .where (label_pairwise_mask , pae_labels , ignore )
38933910 pae_loss = F .cross_entropy (ch_logits .pae , pae_labels , ignore_index = ignore )
38943911
38953912 if exists (pde_labels ):
3896- pde_labels = torch .where (pairwise_mask , pde_labels , ignore )
3913+ pde_labels = torch .where (label_pairwise_mask , pde_labels , ignore )
38973914 pde_loss = F .cross_entropy (ch_logits .pde , pde_labels , ignore_index = ignore )
38983915
38993916 if exists (plddt_labels ):
3900- plddt_labels = torch .where (mask , plddt_labels , ignore )
3917+ plddt_labels = torch .where (label_mask , plddt_labels , ignore )
39013918 plddt_loss = F .cross_entropy (ch_logits .plddt , plddt_labels , ignore_index = ignore )
39023919
39033920 if exists (resolved_labels ):
3904- resolved_labels = torch .where (mask , resolved_labels , ignore )
3921+ resolved_labels = torch .where (label_mask , resolved_labels , ignore )
39053922 resolved_loss = F .cross_entropy (ch_logits .resolved , resolved_labels , ignore_index = ignore )
39063923
39073924 confidence_loss = pae_loss + pde_loss + plddt_loss + resolved_loss
0 commit comments