@@ -3376,7 +3376,9 @@ def __init__(
33763376 self ,
33773377 * ,
33783378 dim_pairwise = 128 ,
3379- num_dist_bins = 38 , # think it is 38?
3379+ num_dist_bins = 38 ,
3380+ dim_atom = 128 ,
3381+ atom_resolution = False
33803382 ):
33813383 super ().__init__ ()
33823384
@@ -3385,13 +3387,42 @@ def __init__(
33853387 Rearrange ('b ... l -> b l ...' )
33863388 )
33873389
3390+ # atom resolution
3391+ # for now, just embed per atom distances, sum to atom features, project to pairwise dimension
3392+
3393+ self .atom_resolution = atom_resolution
3394+
3395+ if atom_resolution :
3396+ self .atom_feats_to_pairwise = LinearNoBiasThenOuterSum (dim_atom , dim_pairwise )
3397+
3398+ # tensor typing
3399+
3400+ self .da = dim_atom
3401+
33883402 @typecheck
33893403 def forward (
33903404 self ,
3391- pairwise_repr : Float ['b n n d' ]
3392- ) -> Float ['b l n n' ]:
3405+ pairwise_repr : Float ['b n n d' ],
3406+ molecule_atom_lens : Int ['b n' ] | None = None ,
3407+ atom_feats : Float ['b m {self.da}' ] | None = None ,
3408+ ) -> Float ['b l n n' ] | Float ['b l m m' ]:
3409+
3410+ if self .atom_resolution :
3411+ assert exists (molecule_atom_lens )
3412+ assert exists (atom_feats )
3413+
3414+ pairwise_repr = batch_repeat_interleave (pairwise_repr , molecule_atom_lens )
3415+
3416+ molecule_atom_lens = repeat (molecule_atom_lens , 'b ... -> (b r) ...' , r = pairwise_repr .shape [1 ])
3417+ pairwise_repr , unpack_one = pack_one (pairwise_repr , '* n d' )
3418+ pairwise_repr = batch_repeat_interleave (pairwise_repr , molecule_atom_lens )
3419+ pairwise_repr = unpack_one (pairwise_repr )
3420+
3421+ pairwise_repr = pairwise_repr + self .atom_feats_to_pairwise (atom_feats )
3422+
3423+ symmetrized_pairwise_repr = pairwise_repr + rearrange (pairwise_repr , 'b i j d -> b j i d' )
3424+ logits = self .to_distogram_logits (symmetrized_pairwise_repr )
33933425
3394- logits = self .to_distogram_logits (pairwise_repr )
33953426 return logits
33963427
33973428# confidence head
@@ -4973,6 +5004,7 @@ def __init__(
49735004 augment_kwargs : dict = dict (),
49745005 stochastic_frame_average = False ,
49755006 confidence_head_atom_resolution = False ,
5007+ distogram_atom_resolution = False ,
49765008 checkpoint_input_embedding = False ,
49775009 checkpoint_trunk_pairformer = False ,
49785010 checkpoint_diffusion_token_transformer = False ,
@@ -5147,9 +5179,13 @@ def __init__(
51475179
51485180 assert len (distance_bins_tensor ) == num_dist_bins , '`distance_bins` must have a length equal to the `num_dist_bins` passed in'
51495181
5182+ self .distogram_atom_resolution = distogram_atom_resolution
5183+
51505184 self .distogram_head = DistogramHead (
51515185 dim_pairwise = dim_pairwise ,
5152- num_dist_bins = num_dist_bins
5186+ dim_atom = dim_atom ,
5187+ num_dist_bins = num_dist_bins ,
5188+ atom_resolution = distogram_atom_resolution ,
51535189 )
51545190
51555191 # pae related bins and modules
@@ -5635,22 +5671,41 @@ def forward(
56355671 molecule_pos = None
56365672
56375673 if not exists (distance_labels ) and atom_pos_given and exists (distogram_atom_indices ):
5638- # molecule_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, distogram_atom_indices)
56395674
5640- distogram_atom_indices = repeat (distogram_atom_indices , 'b n -> b n c' , c = atom_pos .shape [- 1 ])
5641- molecule_pos = atom_pos .gather (1 , distogram_atom_indices )
5675+ distogram_pos = atom_pos
56425676
5643- molecule_dist = torch .cdist (molecule_pos , molecule_pos , p = 2 )
5644- distance_labels = distance_to_bins (molecule_dist , self .distance_bins )
5677+ if not self .distogram_atom_resolution :
5678+ # molecule_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, distogram_atom_indices)
5679+
5680+ distogram_atom_indices = repeat (distogram_atom_indices , 'b n -> b n c' , c = distogram_pos .shape [- 1 ])
5681+ molecule_pos = distogram_pos = distogram_pos .gather (1 , distogram_atom_indices )
5682+ distogram_mask = valid_distogram_mask
5683+ else :
5684+ distogram_mask = atom_mask
5685+
5686+ distogram_dist = torch .cdist (distogram_pos , distogram_pos , p = 2 )
5687+ distance_labels = distance_to_bins (distogram_dist , self .distance_bins )
56455688
56465689 # account for representative distogram atom missing from residue (-1 set on distogram_atom_indices field)
56475690
5648- valid_distogram_mask = to_pairwise_mask (valid_distogram_mask )
5649- distance_labels .masked_fill_ (~ valid_distogram_mask , ignore )
5691+ distogram_mask = to_pairwise_mask (distogram_mask )
5692+ distance_labels .masked_fill_ (~ distogram_mask , ignore )
56505693
56515694 if exists (distance_labels ):
5652- distance_labels = torch .where (pairwise_mask , distance_labels , ignore )
5653- distogram_logits = self .distogram_head (pairwise )
5695+
5696+ distogram_mask = pairwise_mask
5697+
5698+ if self .distogram_atom_resolution :
5699+ distogram_mask = to_pairwise_mask (atom_mask )
5700+
5701+ distance_labels = torch .where (distogram_mask , distance_labels , ignore )
5702+
5703+ distogram_logits = self .distogram_head (
5704+ pairwise ,
5705+ molecule_atom_lens = molecule_atom_lens ,
5706+ atom_feats = atom_feats
5707+ )
5708+
56545709 distogram_loss = F .cross_entropy (distogram_logits , distance_labels , ignore_index = ignore )
56555710
56565711 # otherwise, noise and make it learn to denoise
@@ -5771,7 +5826,12 @@ def forward(
57715826 denoised_molecule_pos = None
57725827
57735828 if not ch_atom_res :
5774- assert exists (molecule_pos ), '`distogram_atom_indices` must be passed in for calculating non-atomic PAE labels'
5829+ if not exists (molecule_pos ):
5830+ assert exists (distogram_atom_indices ), '`distogram_atom_indices` must be passed in for calculating non-atomic PAE labels'
5831+
5832+ distogram_atom_indices = repeat (distogram_atom_indices , 'b n -> b n c' , c = distogram_pos .shape [- 1 ])
5833+ molecule_pos = atom_pos .gather (1 , distogram_atom_indices )
5834+
57755835 denoised_molecule_pos = denoised_atom_pos .gather (1 , distogram_atom_indices )
57765836
57775837 # three_atoms = einx.get_at('b [m] c, b n three -> three b n c', atom_pos, atom_indices_for_frame)
0 commit comments