@@ -3520,8 +3520,7 @@ def forward(
35203520
35213521 intermolecule_dist = torch .cdist (pred_molecule_pos , pred_molecule_pos , p = 2 )
35223522
3523- dist_from_dist_bins = einx .subtract ('b m dist, dist_bins -> b m dist dist_bins' , intermolecule_dist , self .atompair_dist_bins ).abs ()
3524- dist_bin_indices = dist_from_dist_bins .argmin (dim = - 1 )
3523+ dist_bin_indices = distance_to_bins (intermolecule_dist , self .atompair_dist_bins )
35253524 pairwise_repr = pairwise_repr + self .dist_bin_pairwise_embed (dist_bin_indices )
35263525
35273526 # pairformer stack
@@ -3546,8 +3545,7 @@ def forward(
35463545
35473546 interatomic_dist = torch .cdist (pred_atom_pos , pred_atom_pos , p = 2 )
35483547
3549- dist_from_dist_bins = einx .subtract ('b m dist, dist_bins -> b m dist dist_bins' , interatomic_dist , self .atompair_dist_bins ).abs ()
3550- dist_bin_indices = dist_from_dist_bins .argmin (dim = - 1 )
3548+ dist_bin_indices = distance_to_bins (interatomic_dist , self .atompair_dist_bins )
35513549 pairwise_repr = pairwise_repr + self .dist_bin_pairwise_embed (dist_bin_indices )
35523550
35533551 single_repr = single_repr + self .atom_feats_to_single (atom_feats )
@@ -4889,11 +4887,12 @@ def __init__(
48894887 num_atompair_embeds : int | None = None ,
48904888 num_molecule_mods : int | None = DEFAULT_NUM_MOLECULE_MODS ,
48914889 distance_bins : List [float ] = torch .linspace (3 , 20 , 38 ).float ().tolist (),
4890+ pae_bins : List [float ] = torch .linspace (0.5 , 32 , 64 ).float ().tolist (),
48924891 ignore_index = - 1 ,
48934892 num_dist_bins : int | None = None ,
48944893 num_plddt_bins = 50 ,
48954894 num_pde_bins = 64 ,
4896- num_pae_bins = 64 ,
4895+ num_pae_bins : int | None = None ,
48974896 sigma_data = 16 ,
48984897 num_rollout_steps = 20 ,
48994898 diffusion_num_augmentations = 4 ,
@@ -5126,21 +5125,30 @@ def __init__(
51265125 ** edm_kwargs
51275126 )
51285127
5128+ self .num_rollout_steps = num_rollout_steps
5129+
51295130 # logit heads
51305131
51315132 distance_bins_tensor = Tensor (distance_bins )
51325133
51335134 self .register_buffer ('distance_bins' , distance_bins_tensor )
51345135 num_dist_bins = default (num_dist_bins , len (distance_bins_tensor ))
51355136
5137+
51365138 assert len (distance_bins_tensor ) == num_dist_bins , '`distance_bins` must have a length equal to the `num_dist_bins` passed in'
51375139
51385140 self .distogram_head = DistogramHead (
51395141 dim_pairwise = dim_pairwise ,
51405142 num_dist_bins = num_dist_bins
51415143 )
51425144
5143- self .num_rollout_steps = num_rollout_steps
5145+ # pae bins
5146+
5147+ pae_bins_tensor = Tensor (pae_bins )
5148+ self .register_buffer ('pae_bins' , pae_bins_tensor )
5149+ num_pae_bins = len (pae_bins )
5150+
5151+ # confidence head
51445152
51455153 self .confidence_head = ConfidenceHead (
51465154 dim_single_inputs = dim_single_inputs ,
0 commit comments