@@ -3639,19 +3639,17 @@ def forward(
36393639 if exists (atom_mask ):
36403640 sampled_atom_pos = einx .where ('b m, b m c, -> b m c' , atom_mask , sampled_atom_pos , 0. )
36413641
3642+ if return_confidence_head_logits :
3643+ assert exists (molecule_atom_indices )
3644+ pred_atom_pos = einx .get_at ('b [m] c, b n -> b n c' , sampled_atom_pos , molecule_atom_indices )
3645+
36423646 if exists (missing_atom_mask ) and return_present_sampled_atoms :
36433647 sampled_atom_pos = sampled_atom_pos [~ missing_atom_mask ]
36443648
36453649 if not return_confidence_head_logits :
36463650 return sampled_atom_pos
36473651
3648- # todo - handle missing atoms
3649-
3650- assert exists (molecule_atom_indices )
3651-
3652- pred_atom_pos = einx .get_at ('b [m] c, b n -> b n c' , sampled_atom_pos , molecule_atom_indices )
3653-
3654- logits = self .confidence_head (
3652+ confidence_head_logits = self .confidence_head (
36553653 single_repr = single .detach (),
36563654 single_inputs_repr = single_inputs .detach (),
36573655 pairwise_repr = pairwise .detach (),
@@ -3660,7 +3658,7 @@ def forward(
36603658 return_pae_logits = True
36613659 )
36623660
3663- return sampled_atom_pos , logits
3661+ return sampled_atom_pos , confidence_head_logits
36643662
36653663 # if being forced to return loss, but do not have sufficient information to return losses, just return 0
36663664
@@ -3826,7 +3824,7 @@ def forward(
38263824
38273825 pred_atom_pos = einx .get_at ('b [m] c, b n -> b n c' , denoised_atom_pos , molecule_atom_indices )
38283826
3829- logits = self .confidence_head (
3827+ ch_logits = self .confidence_head (
38303828 single_repr = single .detach (),
38313829 single_inputs_repr = single_inputs .detach (),
38323830 pairwise_repr = pairwise .detach (),
@@ -3837,19 +3835,19 @@ def forward(
38373835
38383836 if exists (pae_labels ):
38393837 pae_labels = torch .where (pairwise_mask , pae_labels , ignore )
3840- pae_loss = F .cross_entropy (logits .pae , pae_labels , ignore_index = ignore )
3838+ pae_loss = F .cross_entropy (ch_logits .pae , pae_labels , ignore_index = ignore )
38413839
38423840 if exists (pde_labels ):
38433841 pde_labels = torch .where (pairwise_mask , pde_labels , ignore )
3844- pde_loss = F .cross_entropy (logits .pde , pde_labels , ignore_index = ignore )
3842+ pde_loss = F .cross_entropy (ch_logits .pde , pde_labels , ignore_index = ignore )
38453843
38463844 if exists (plddt_labels ):
38473845 plddt_labels = torch .where (mask , plddt_labels , ignore )
3848- plddt_loss = F .cross_entropy (logits .plddt , plddt_labels , ignore_index = ignore )
3846+ plddt_loss = F .cross_entropy (ch_logits .plddt , plddt_labels , ignore_index = ignore )
38493847
38503848 if exists (resolved_labels ):
38513849 resolved_labels = torch .where (mask , resolved_labels , ignore )
3852- resolved_loss = F .cross_entropy (logits .resolved , resolved_labels , ignore_index = ignore )
3850+ resolved_loss = F .cross_entropy (ch_logits .resolved , resolved_labels , ignore_index = ignore )
38533851
38543852 confidence_loss = pae_loss + pde_loss + plddt_loss + resolved_loss
38553853
0 commit comments