@@ -2925,14 +2925,24 @@ def forward(
29252925 single_inputs_repr : Float ['b n dsi' ],
29262926 single_repr : Float ['b n ds' ],
29272927 pairwise_repr : Float ['b n n dp' ],
2928- pred_atom_pos : Float ['b n 3' ],
2928+ pred_atom_pos : Float ['b n 3' ] | Float ['b m 3' ],
2929+ molecule_atom_indices : Int ['b n' ] | None = None ,
29292930 mask : Bool ['b n' ] | None = None ,
29302931 return_pae_logits = True
29312932
29322933 ) -> ConfidenceHeadLogits :
29332934
29342935 pairwise_repr = pairwise_repr + self .single_inputs_to_pairwise (single_inputs_repr )
29352936
2937+ # pluck out the representative atoms for non-atomic resolution confidence head
2938+
2939+ is_atom_seq = pred_atom_pos .shape [- 2 ] > single_inputs_repr .shape [- 2 ]
2940+
2941+ assert not is_atom_seq or exists (molecule_atom_indices )
2942+
2943+ if is_atom_seq :
2944+ pred_atom_pos = einx .get_at ('b [m] c, b n -> b n c' , pred_atom_pos , molecule_atom_indices )
2945+
29362946 # interatomic distances - embed and add to pairwise
29372947
29382948 interatom_dist = torch .cdist (pred_atom_pos , pred_atom_pos , p = 2 )
@@ -3640,8 +3650,7 @@ def forward(
36403650 sampled_atom_pos = einx .where ('b m, b m c, -> b m c' , atom_mask , sampled_atom_pos , 0. )
36413651
36423652 if return_confidence_head_logits :
3643- assert exists (molecule_atom_indices )
3644- pred_atom_pos = einx .get_at ('b [m] c, b n -> b n c' , sampled_atom_pos , molecule_atom_indices )
3653+ confidence_head_atom_pos_input = sampled_atom_pos .clone ()
36453654
36463655 if exists (missing_atom_mask ) and return_present_sampled_atoms :
36473656 sampled_atom_pos = sampled_atom_pos [~ missing_atom_mask ]
@@ -3653,7 +3662,8 @@ def forward(
36533662 single_repr = single .detach (),
36543663 single_inputs_repr = single_inputs .detach (),
36553664 pairwise_repr = pairwise .detach (),
3656- pred_atom_pos = pred_atom_pos .detach (),
3665+ pred_atom_pos = confidence_head_atom_pos_input .detach (),
3666+ molecule_atom_indices = molecule_atom_indices ,
36573667 mask = mask ,
36583668 return_pae_logits = True
36593669 )
@@ -3822,13 +3832,12 @@ def forward(
38223832 tqdm_pbar_title = 'training rollout'
38233833 )
38243834
3825- pred_atom_pos = einx .get_at ('b [m] c, b n -> b n c' , denoised_atom_pos , molecule_atom_indices )
3826-
38273835 ch_logits = self .confidence_head (
38283836 single_repr = single .detach (),
38293837 single_inputs_repr = single_inputs .detach (),
38303838 pairwise_repr = pairwise .detach (),
3831- pred_atom_pos = pred_atom_pos .detach (),
3839+ pred_atom_pos = denoised_atom_pos .detach (),
3840+ molecule_atom_indices = molecule_atom_indices ,
38323841 mask = mask ,
38333842 return_pae_logits = return_pae_logits
38343843 )
0 commit comments