@@ -3392,7 +3392,6 @@ def __init__(
33923392 dim_pairwise = 128 ,
33933393 num_dist_bins = 38 ,
33943394 dim_atom = 128 ,
3395- atom_resolution = False
33963395 ):
33973396 super ().__init__ ()
33983397
@@ -3404,10 +3403,7 @@ def __init__(
34043403 # atom resolution
34053404 # for now, just embed per atom distances, sum to atom features, project to pairwise dimension
34063405
3407- self .atom_resolution = atom_resolution
3408-
3409- if atom_resolution :
3410- self .atom_feats_to_pairwise = LinearNoBiasThenOuterSum (dim_atom , dim_pairwise )
3406+ self .atom_feats_to_pairwise = LinearNoBiasThenOuterSum (dim_atom , dim_pairwise )
34113407
34123408 # tensor typing
34133409
@@ -3417,16 +3413,12 @@ def __init__(
34173413 def forward (
34183414 self ,
34193415 pairwise_repr : Float ['b n n d' ],
3420- molecule_atom_lens : Int ['b n' ] | None = None ,
3421- atom_feats : Float ['b m {self.da}' ] | None = None ,
3416+ molecule_atom_lens : Int ['b n' ],
3417+ atom_feats : Float ['b m {self.da}' ],
34223418 ) -> Float ['b l n n' ] | Float ['b l m m' ]:
34233419
3424- if self .atom_resolution :
3425- assert exists (molecule_atom_lens )
3426- assert exists (atom_feats )
3427-
3428- pairwise_repr = batch_repeat_interleave_pairwise (pairwise_repr , molecule_atom_lens )
3429- pairwise_repr = pairwise_repr + self .atom_feats_to_pairwise (atom_feats )
3420+ pairwise_repr = batch_repeat_interleave_pairwise (pairwise_repr , molecule_atom_lens )
3421+ pairwise_repr = pairwise_repr + self .atom_feats_to_pairwise (atom_feats )
34303422
34313423 logits = self .to_distogram_logits (symmetrize (pairwise_repr ))
34323424
@@ -4989,7 +4981,6 @@ def __init__(
49894981 lddt_mask_other_cutoff = 15. ,
49904982 augment_kwargs : dict = dict (),
49914983 stochastic_frame_average = False ,
4992- distogram_atom_resolution = False ,
49934984 checkpoint_input_embedding = False ,
49944985 checkpoint_trunk_pairformer = False ,
49954986 checkpoint_distogram_head = False ,
@@ -5171,13 +5162,10 @@ def __init__(
51715162
51725163 assert len (distance_bins_tensor ) == num_dist_bins , '`distance_bins` must have a length equal to the `num_dist_bins` passed in'
51735164
5174- self .distogram_atom_resolution = distogram_atom_resolution
5175-
51765165 self .distogram_head = DistogramHead (
51775166 dim_pairwise = dim_pairwise ,
51785167 dim_atom = dim_atom ,
51795168 num_dist_bins = num_dist_bins ,
5180- atom_resolution = distogram_atom_resolution ,
51815169 )
51825170
51835171 # lddt related
@@ -5679,15 +5667,7 @@ def forward(
56795667 if not exists (distance_labels ) and atom_pos_given and exists (distogram_atom_indices ):
56805668
56815669 distogram_pos = atom_pos
5682-
5683- if not self .distogram_atom_resolution :
5684- # molecule_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, distogram_atom_indices)
5685-
5686- distogram_atom_indices = repeat (distogram_atom_indices , 'b n -> b n c' , c = distogram_pos .shape [- 1 ])
5687- molecule_pos = distogram_pos = distogram_pos .gather (1 , distogram_atom_indices )
5688- distogram_mask = valid_distogram_mask
5689- else :
5690- distogram_mask = atom_mask
5670+ distogram_mask = atom_mask
56915671
56925672 distogram_dist = torch .cdist (distogram_pos , distogram_pos , p = 2 )
56935673 distance_labels = distance_to_bins (distogram_dist , self .distance_bins )
@@ -5700,9 +5680,7 @@ def forward(
57005680 if exists (distance_labels ):
57015681
57025682 distogram_mask = pairwise_mask
5703-
5704- if self .distogram_atom_resolution :
5705- distogram_mask = to_pairwise_mask (atom_mask )
5683+ distogram_mask = to_pairwise_mask (atom_mask )
57065684
57075685 distance_labels = torch .where (distogram_mask , distance_labels , ignore )
57085686
0 commit comments