@@ -504,7 +504,7 @@ def test_alphafold3(
504504 template_feats = torch .randn (2 , 2 , seq_len , seq_len , 44 )
505505 template_mask = torch .ones ((2 , 2 )).bool ()
506506
507- msa = torch .randn (2 , 7 , seq_len , 64 )
507+ msa = torch .randn (2 , 7 , seq_len , 8 )
508508 msa_mask = torch .ones ((2 , 7 )).bool ()
509509
510510 atom_pos = torch .randn (2 , atom_seq_len , 3 )
@@ -519,7 +519,9 @@ def test_alphafold3(
519519
520520 alphafold3 = Alphafold3 (
521521 dim_atom_inputs = 77 ,
522- dim_pairwise = 64 ,
522+ dim_pairwise = 8 ,
523+ dim_single = 8 ,
524+ dim_token = 8 ,
523525 atoms_per_window = atoms_per_window ,
524526 dim_template_feats = 44 ,
525527 num_dist_bins = 38 ,
@@ -531,15 +533,28 @@ def test_alphafold3(
531533 pairformer_stack_depth = 1
532534 ),
533535 msa_module_kwargs = dict (
534- depth = 1
536+ depth = 1 ,
537+ dim_msa = 8 ,
535538 ),
536- pairformer_stack = dict (
537- depth = 2
539+ pairformer_stack = dict (
540+ depth = 1 ,
541+ pair_bias_attn_dim_head = 4 ,
542+ pair_bias_attn_heads = 2 ,
538543 ),
539- diffusion_module_kwargs = dict (
540- atom_encoder_depth = 1 ,
541- token_transformer_depth = 1 ,
542- atom_decoder_depth = 1 ,
544+ diffusion_module_kwargs = dict (
545+ atom_encoder_depth = 1 ,
546+ token_transformer_depth = 1 ,
547+ atom_decoder_depth = 1 ,
548+ atom_decoder_kwargs = dict (
549+ attn_pair_bias_kwargs = dict (
550+ dim_head = 4
551+ )
552+ ),
553+ atom_encoder_kwargs = dict (
554+ attn_pair_bias_kwargs = dict (
555+ dim_head = 4
556+ )
557+ )
543558 ),
544559 stochastic_frame_average = stochastic_frame_average ,
545560 confidence_head_atom_resolution = confidence_head_atom_resolution
@@ -569,6 +584,7 @@ def test_alphafold3(
569584 pde_labels = pde_labels ,
570585 plddt_labels = plddt_labels ,
571586 resolved_labels = resolved_labels ,
587+ num_rollout_steps = 1 ,
572588 diffusion_add_smooth_lddt_loss = True ,
573589 return_loss_breakdown = True
574590 )
0 commit comments