@@ -698,7 +698,7 @@ def test_alphafold3(
698698 dim_token = 8 ,
699699 atoms_per_window = atoms_per_window ,
700700 dim_template_feats = 108 ,
701- num_dist_bins = 38 ,
701+ num_dist_bins = 64 ,
702702 num_molecule_mods = num_molecule_mods ,
703703 confidence_head_kwargs = dict (
704704 pairformer_depth = 1
@@ -804,7 +804,7 @@ def test_alphafold3_without_msa_and_templates():
804804 alphafold3 = Alphafold3 (
805805 dim_atom_inputs = 77 ,
806806 dim_template_feats = 108 ,
807- num_dist_bins = 38 ,
807+ num_dist_bins = 64 ,
808808 num_molecule_mods = 0 ,
809809 checkpoint_trunk_pairformer = True ,
810810 checkpoint_diffusion_module = True ,
@@ -871,7 +871,7 @@ def test_alphafold3_force_return_loss():
871871 alphafold3 = Alphafold3 (
872872 dim_atom_inputs = 77 ,
873873 dim_template_feats = 108 ,
874- num_dist_bins = 38 ,
874+ num_dist_bins = 64 ,
875875 num_molecule_mods = 0 ,
876876 confidence_head_kwargs = dict (
877877 pairformer_depth = 1
@@ -953,7 +953,7 @@ def test_alphafold3_force_return_loss_with_confidence_logits():
953953 alphafold3 = Alphafold3 (
954954 dim_atom_inputs = 77 ,
955955 dim_template_feats = 108 ,
956- num_dist_bins = 38 ,
956+ num_dist_bins = 64 ,
957957 num_molecule_mods = 0 ,
958958 confidence_head_kwargs = dict (
959959 pairformer_depth = 1
@@ -1170,7 +1170,7 @@ def test_model_selection_score():
11701170 atom_mask = torch .randint (0 , 2 , (atom_pos_true .shape [:- 1 ])).type_as (atom_pos_true ).bool ()
11711171 tok_repr_atm_mask = torch .randint (0 , 2 , (batch_size , seq_len )).bool ()
11721172
1173- dist_logits = torch .randn (batch_size , 38 , seq_len , seq_len )
1173+ dist_logits = torch .randn (batch_size , 64 , seq_len , seq_len )
11741174 pde_logits = torch .randn (batch_size , 64 , seq_len , seq_len )
11751175
11761176 chain_length = [random .randint (seq_len // 4 , seq_len // 2 )
@@ -1222,7 +1222,7 @@ def test_model_selection_score_end_to_end():
12221222 dim_token = 8 ,
12231223 atoms_per_window = 27 ,
12241224 dim_template_feats = 108 ,
1225- num_dist_bins = 38 ,
1225+ num_dist_bins = 64 ,
12261226 confidence_head_kwargs = dict (
12271227 pairformer_depth = 1
12281228 ),
@@ -1427,7 +1427,7 @@ def test_readme2():
14271427 dim_atompair_inputs = 5 ,
14281428 atoms_per_window = 27 ,
14291429 dim_template_feats = 108 ,
1430- num_dist_bins = 38 ,
1430+ num_dist_bins = 64 ,
14311431 num_molecule_mods = 0 ,
14321432 confidence_head_kwargs = dict (
14331433 pairformer_depth = 1
0 commit comments