@@ -2853,10 +2853,10 @@ def forward(
28532853# confidence head
28542854
28552855class ConfidenceHeadLogits (NamedTuple ):
2856- pae : Float ['b pae n n' ] | None
2857- pde : Float ['b pde n n' ]
2858- plddt : Float ['b plddt n' ]
2859- resolved : Float ['b 2 n' ]
2856+ pae : Float ['b pae n n' ] | Float [ 'b pae m m' ] | None
2857+ pde : Float ['b pde n n' ] | Float [ 'b pde m m' ]
2858+ plddt : Float ['b plddt n' ] | Float [ 'b plddt m' ]
2859+ resolved : Float ['b 2 n' ] | Float [ 'b 2 m' ]
28602860
28612861class ConfidenceHead (Module ):
28622862 """ Algorithm 31 """
@@ -2866,6 +2866,8 @@ def __init__(
28662866 self ,
28672867 * ,
28682868 dim_single_inputs ,
2869+ atom_resolution = False , # @amorehead discovers that the public api has per-atom resolution confidences. improvise a solution
2870+ dim_atom = 128 ,
28692871 atompair_dist_bins : List [float ],
28702872 dim_single = 384 ,
28712873 dim_pairwise = 128 ,
@@ -2918,6 +2920,19 @@ def __init__(
29182920 Rearrange ('b ... l -> b l ...' )
29192921 )
29202922
2923+ # atom resolution
2924+ # for now, just embed per atom distances, sum to atom features, project to pairwise dimension
2925+
2926+ self .atom_resolution = atom_resolution
2927+
2928+ if atom_resolution :
2929+ self .atom_feats_to_single = LinearNoBias (dim_atom , dim_single )
2930+ self .atom_feats_to_pairwise = LinearNoBiasThenOuterSum (dim_atom , dim_pairwise )
2931+
2932+ # tensor typing
2933+
2934+ self .da = dim_atom
2935+
29212936 @typecheck
29222937 def forward (
29232938 self ,
@@ -2927,7 +2942,9 @@ def forward(
29272942 pairwise_repr : Float ['b n n dp' ],
29282943 pred_atom_pos : Float ['b n 3' ] | Float ['b m 3' ],
29292944 molecule_atom_indices : Int ['b n' ] | None = None ,
2945+ molecule_atom_lens : Int ['b n' ] | None = None ,
29302946 mask : Bool ['b n' ] | None = None ,
2947+ atom_feats : Float ['b m {self.da}' ] | None = None ,
29312948 return_pae_logits = True
29322949
29332950 ) -> ConfidenceHeadLogits :
@@ -2938,16 +2955,24 @@ def forward(
29382955
29392956 is_atom_seq = pred_atom_pos .shape [- 2 ] > single_inputs_repr .shape [- 2 ]
29402957
2941- assert not is_atom_seq or exists (molecule_atom_indices )
2958+ # handle atom resolution vs not
2959+
2960+ if self .atom_resolution :
2961+ assert exists (atom_feats ), 'atom_feats must be passed in if atom_resolution is turned on for ConfidenceHead'
2962+ assert is_atom_seq , '`pred_atom_pos` must be passed in with atomic length'
2963+ assert exists (molecule_atom_lens )
29422964
29432965 if is_atom_seq :
2944- pred_atom_pos = einx .get_at ('b [m] c, b n -> b n c' , pred_atom_pos , molecule_atom_indices )
2966+ assert exists (molecule_atom_indices ), 'molecule_atom_indices must be passed into ConfidenceHead if pred_atom_pos is atomic length'
2967+ pred_molecule_pos = einx .get_at ('b [m] c, b n -> b n c' , pred_atom_pos , molecule_atom_indices )
2968+ else :
2969+ pred_molecule_pos = pred_atom_pos
29452970
29462971 # interatomic distances - embed and add to pairwise
29472972
2948- interatom_dist = torch .cdist (pred_atom_pos , pred_atom_pos , p = 2 )
2973+ intermolecule_dist = torch .cdist (pred_molecule_pos , pred_molecule_pos , p = 2 )
29492974
2950- dist_from_dist_bins = einx .subtract ('b m dist, dist_bins -> b m dist dist_bins' , interatom_dist , self .atompair_dist_bins ).abs ()
2975+ dist_from_dist_bins = einx .subtract ('b m dist, dist_bins -> b m dist dist_bins' , intermolecule_dist , self .atompair_dist_bins ).abs ()
29512976 dist_bin_indices = dist_from_dist_bins .argmin (dim = - 1 )
29522977 pairwise_repr = pairwise_repr + self .dist_bin_pairwise_embed (dist_bin_indices )
29532978
@@ -2959,6 +2984,27 @@ def forward(
29592984 mask = mask
29602985 )
29612986
2987+ # handle maybe atom level resolution
2988+
2989+ if self .atom_resolution :
2990+ single_repr = repeat_consecutive_with_lens (single_repr , molecule_atom_lens )
2991+
2992+ pairwise_repr = repeat_consecutive_with_lens (pairwise_repr , molecule_atom_lens )
2993+
2994+ molecule_atom_lens = repeat (molecule_atom_lens , 'b ... -> (b r) ...' , r = pairwise_repr .shape [1 ])
2995+ pairwise_repr , ps = pack_one (pairwise_repr , '* n d' )
2996+ pairwise_repr = repeat_consecutive_with_lens (pairwise_repr , molecule_atom_lens )
2997+ pairwise_repr = unpack_one (pairwise_repr , ps , '* n d' )
2998+
2999+ interatomic_dist = torch .cdist (pred_atom_pos , pred_atom_pos , p = 2 )
3000+
3001+ dist_from_dist_bins = einx .subtract ('b m dist, dist_bins -> b m dist dist_bins' , interatomic_dist , self .atompair_dist_bins ).abs ()
3002+ dist_bin_indices = dist_from_dist_bins .argmin (dim = - 1 )
3003+ pairwise_repr = pairwise_repr + self .dist_bin_pairwise_embed (dist_bin_indices )
3004+
3005+ single_repr = single_repr + self .atom_feats_to_single (atom_feats )
3006+ pairwise_repr = pairwise_repr + self .atom_feats_to_pairwise (atom_feats )
3007+
29623008 # to logits
29633009
29643010 symmetric_pairwise_repr = pairwise_repr + rearrange (pairwise_repr , 'b i j d -> b j i d' )
0 commit comments