@@ -3428,7 +3428,7 @@ def forward(
34283428
34293429class ConfidenceHeadLogits (NamedTuple ):
34303430 pae : Float ['b pae m m' ] | None
3431- pde : Float ['b pde n n ' ]
3431+ pde : Float ['b pde m m ' ]
34323432 plddt : Float ['b plddt m' ]
34333433 resolved : Float ['b 2 m' ]
34343434
@@ -3567,7 +3567,7 @@ def forward(
35673567
35683568 # to logits
35693569
3570- pde_logits = self .to_pde_logits (symmetrize (pairwise_repr ))
3570+ pde_logits = self .to_pde_logits (symmetrize (atom_pairwise_repr ))
35713571
35723572 plddt_logits = self .to_plddt_logits (atom_single_repr )
35733573 resolved_logits = self .to_resolved_logits (atom_single_repr )
@@ -4344,7 +4344,7 @@ def can_calculate_unresolved_protein_rasa(self):
43444344 @typecheck
43454345 def compute_gpde (
43464346 self ,
4347- pde_logits : Float ["b pde n n" ],
4347+ pde_logits : Float ["b pde n n" ],
43484348 dist_logits : Float ["b dist n n" ],
43494349 dist_breaks : Float [" dist_break" ],
43504350 tok_repr_atm_mask : Bool ["b n" ],
@@ -4852,6 +4852,7 @@ def compute_model_selection_score(
48524852 top_ranked_sample = max (
48534853 scored_samples , key = lambda x : x [- 1 ].mean ()
48544854 ) # rank by batch-averaged gPDE
4855+
48554856 best_of_5_sample = max (
48564857 scored_samples , key = lambda x : x [- 2 ].mean ()
48574858 ) # rank by batch-averaged lDDT
@@ -5898,29 +5899,11 @@ def forward(
58985899 pde_labels = None
58995900
59005901 if atom_pos_given :
5901- denoised_molecule_pos = None
5902-
5903- assert exists (
5904- molecule_atom_indices
5905- ), "`molecule_atom_indices` must be passed in for calculating non-atomic PDE labels"
5906-
5907- # molecule_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, molecule_atom_indices)
5908-
5909- mol_atom_indices = repeat (
5910- molecule_atom_indices , "b n -> b n c" , c = atom_pos .shape [- 1 ]
5911- )
59125902
5913- molecule_pos = atom_pos .gather (1 , mol_atom_indices )
5914- denoised_molecule_pos = denoised_atom_pos .gather (1 , mol_atom_indices )
5903+ pde_atom_mask = batch_repeat_interleave (valid_molecule_atom_mask , molecule_atom_lens )
59155904
5916- molecule_mask = valid_molecule_atom_mask
5917-
5918- pde_gt_dist = torch .cdist (molecule_pos , molecule_pos , p = 2 )
5919- pde_pred_dist = torch .cdist (
5920- denoised_molecule_pos ,
5921- denoised_molecule_pos ,
5922- p = 2 ,
5923- )
5905+ pde_gt_dist = torch .cdist (atom_pos , atom_pos )
5906+ pde_pred_dist = torch .cdist (denoised_atom_pos , denoised_atom_pos )
59245907
59255908 # calculate pde labels as distance error binned to 64 (0 - 32A)
59265909
@@ -5929,8 +5912,8 @@ def forward(
59295912
59305913 # account for representative molecule atom missing from residue (-1 set on molecule_atom_indices field)
59315914
5932- molecule_mask = to_pairwise_mask (molecule_mask )
5933- pde_labels .masked_fill_ (~ molecule_mask , ignore )
5915+ pde_pairwise_atom_mask = to_pairwise_mask (pde_atom_mask )
5916+ pde_labels .masked_fill_ (~ pde_pairwise_atom_mask , ignore )
59345917
59355918 # determine plddt labels if possible
59365919
@@ -6016,7 +5999,17 @@ def forward(
60165999
60176000 confidence_weight = confidence_mask .float ()
60186001
6019- def cross_entropy_with_weight (logits , labels , weight , ignore_index : int ):
6002+ @typecheck
6003+ def cross_entropy_with_weight (
6004+ logits : Float ['b l ...' ],
6005+ labels : Int ['b ...' ],
6006+ weight : Float [' b' ],
6007+ mask : Bool ['b ...' ],
6008+ ignore_index : int
6009+ ) -> Float ['' ]:
6010+
6011+ labels = torch .where (mask , labels , ignore_index )
6012+
60206013 return F .cross_entropy (
60216014 einx .multiply ('b ..., b -> b ...' , logits , weight ),
60226015 einx .multiply ('b ..., b -> b ...' , labels , weight .long ()),
@@ -6028,32 +6021,28 @@ def cross_entropy_with_weight(logits, labels, weight, ignore_index: int):
60286021 f"pae_labels shape { pae_labels .shape [- 1 ]} does not match "
60296022 f"ch_logits.pae shape { ch_logits .pae .shape [- 1 ]} "
60306023 )
6031- pae_labels = torch .where (label_pairwise_mask , pae_labels , ignore )
6032- pae_loss = cross_entropy_with_weight (ch_logits .pae , pae_labels , confidence_weight , ignore )
6024+ pae_loss = cross_entropy_with_weight (ch_logits .pae , pae_labels , confidence_weight , label_pairwise_mask , ignore )
60336025
60346026 if exists (pde_labels ):
60356027 assert pde_labels .shape [- 1 ] == ch_logits .pde .shape [- 1 ], (
60366028 f"pde_labels shape { pde_labels .shape [- 1 ]} does not match "
60376029 f"ch_logits.pde shape { ch_logits .pde .shape [- 1 ]} "
60386030 )
6039- pde_labels = torch .where (to_pairwise_mask (mask ), pde_labels , ignore )
6040- pde_loss = cross_entropy_with_weight (ch_logits .pde , pde_labels , confidence_weight , ignore )
6031+ pde_loss = cross_entropy_with_weight (ch_logits .pde , pde_labels , confidence_weight , label_pairwise_mask , ignore )
60416032
60426033 if exists (plddt_labels ):
60436034 assert plddt_labels .shape [- 1 ] == ch_logits .plddt .shape [- 1 ], (
60446035 f"plddt_labels shape { plddt_labels .shape [- 1 ]} does not match "
60456036 f"ch_logits.plddt shape { ch_logits .plddt .shape [- 1 ]} "
60466037 )
6047- plddt_labels = torch .where (label_mask , plddt_labels , ignore )
6048- plddt_loss = cross_entropy_with_weight (ch_logits .plddt , plddt_labels , confidence_weight , ignore )
6038+ plddt_loss = cross_entropy_with_weight (ch_logits .plddt , plddt_labels , confidence_weight , label_mask , ignore )
60496039
60506040 if exists (resolved_labels ):
60516041 assert resolved_labels .shape [- 1 ] == ch_logits .resolved .shape [- 1 ], (
60526042 f"resolved_labels shape { resolved_labels .shape [- 1 ]} does not match "
60536043 f"ch_logits.resolved shape { ch_logits .resolved .shape [- 1 ]} "
60546044 )
6055- resolved_labels = torch .where (label_mask , resolved_labels , ignore )
6056- resolved_loss = cross_entropy_with_weight (ch_logits .resolved , resolved_labels , confidence_weight , ignore )
6045+ resolved_loss = cross_entropy_with_weight (ch_logits .resolved , resolved_labels , confidence_weight , label_mask , ignore )
60576046
60586047 confidence_loss = pae_loss + pde_loss + plddt_loss + resolved_loss
60596048
0 commit comments