@@ -472,6 +472,7 @@ def test_alphafold3(
472472
473473 atom_pos = torch .randn (2 , atom_seq_len , 3 )
474474 distogram_atom_indices = molecule_atom_lens - 1
475+ molecule_atom_indices = molecule_atom_lens - 1
475476
476477 pae_labels = torch .randint (0 , 64 , (2 , seq_len , seq_len ))
477478 pde_labels = torch .randint (0 , 64 , (2 , seq_len , seq_len ))
@@ -524,6 +525,7 @@ def test_alphafold3(
524525 template_mask = template_mask ,
525526 atom_pos = atom_pos ,
526527 distogram_atom_indices = distogram_atom_indices ,
528+ molecule_atom_indices = molecule_atom_indices ,
527529 pae_labels = pae_labels ,
528530 pde_labels = pde_labels ,
529531 plddt_labels = plddt_labels ,
@@ -630,6 +632,7 @@ def test_alphafold3_force_return_loss():
630632
631633 atom_pos = torch .randn (2 , atom_seq_len , 3 )
632634 distogram_atom_indices = molecule_atom_lens - 1
635+ molecule_atom_indices = molecule_atom_lens - 1
633636
634637 distance_labels = torch .randint (0 , 38 , (2 , seq_len , seq_len ))
635638 pae_labels = torch .randint (0 , 64 , (2 , seq_len , seq_len ))
@@ -671,6 +674,7 @@ def test_alphafold3_force_return_loss():
671674 additional_token_feats = additional_token_feats ,
672675 atom_pos = atom_pos ,
673676 distogram_atom_indices = distogram_atom_indices ,
677+ molecule_atom_indices = molecule_atom_indices ,
674678 distance_labels = distance_labels ,
675679 pae_labels = pae_labels ,
676680 pde_labels = pde_labels ,
@@ -682,6 +686,91 @@ def test_alphafold3_force_return_loss():
682686
683687 assert sampled_atom_pos .ndim == 3
684688
689+ loss , _ = alphafold3 (
690+ num_recycling_steps = 2 ,
691+ atom_inputs = atom_inputs ,
692+ molecule_ids = molecule_ids ,
693+ molecule_atom_lens = molecule_atom_lens ,
694+ atompair_inputs = atompair_inputs ,
695+ is_molecule_types = is_molecule_types ,
696+ additional_molecule_feats = additional_molecule_feats ,
697+ additional_token_feats = additional_token_feats ,
698+ molecule_atom_indices = molecule_atom_indices ,
699+ return_loss_breakdown = True ,
700+ return_loss = True # force returning loss even if no labels given
701+ )
702+
703+ assert loss == 0.
704+
705+ def test_alphafold3_force_return_loss_with_confidence_logits ():
706+ seq_len = 16
707+ molecule_atom_lens = torch .randint (1 , 3 , (2 , seq_len ))
708+ atom_seq_len = molecule_atom_lens .sum (dim = - 1 ).amax ()
709+
710+ atom_inputs = torch .randn (2 , atom_seq_len , 77 )
711+ atompair_inputs = torch .randn (2 , atom_seq_len , atom_seq_len , 5 )
712+ additional_molecule_feats = torch .randint (0 , 2 , (2 , seq_len , 5 ))
713+ additional_token_feats = torch .randn (2 , seq_len , 2 )
714+ is_molecule_types = torch .randint (0 , 2 , (2 , seq_len , IS_MOLECULE_TYPES )).bool ()
715+ molecule_ids = torch .randint (0 , 32 , (2 , seq_len ))
716+
717+ atom_pos = torch .randn (2 , atom_seq_len , 3 )
718+ distogram_atom_indices = molecule_atom_lens - 1
719+ molecule_atom_indices = molecule_atom_lens - 1
720+
721+ distance_labels = torch .randint (0 , 38 , (2 , seq_len , seq_len ))
722+ pae_labels = torch .randint (0 , 64 , (2 , seq_len , seq_len ))
723+ pde_labels = torch .randint (0 , 64 , (2 , seq_len , seq_len ))
724+ plddt_labels = torch .randint (0 , 50 , (2 , seq_len ))
725+ resolved_labels = torch .randint (0 , 2 , (2 , seq_len ))
726+
727+ alphafold3 = Alphafold3 (
728+ dim_atom_inputs = 77 ,
729+ dim_template_feats = 44 ,
730+ num_dist_bins = 38 ,
731+ confidence_head_kwargs = dict (
732+ pairformer_depth = 1
733+ ),
734+ template_embedder_kwargs = dict (
735+ pairformer_stack_depth = 1
736+ ),
737+ msa_module_kwargs = dict (
738+ depth = 1
739+ ),
740+ pairformer_stack = dict (
741+ depth = 2
742+ ),
743+ diffusion_module_kwargs = dict (
744+ atom_encoder_depth = 1 ,
745+ token_transformer_depth = 1 ,
746+ atom_decoder_depth = 1 ,
747+ ),
748+ )
749+
750+ sampled_atom_pos , confidence_head_logits = alphafold3 (
751+ num_recycling_steps = 2 ,
752+ atom_inputs = atom_inputs ,
753+ molecule_ids = molecule_ids ,
754+ molecule_atom_lens = molecule_atom_lens ,
755+ atompair_inputs = atompair_inputs ,
756+ is_molecule_types = is_molecule_types ,
757+ additional_molecule_feats = additional_molecule_feats ,
758+ additional_token_feats = additional_token_feats ,
759+ atom_pos = atom_pos ,
760+ distogram_atom_indices = distogram_atom_indices ,
761+ molecule_atom_indices = molecule_atom_indices ,
762+ distance_labels = distance_labels ,
763+ pae_labels = pae_labels ,
764+ pde_labels = pde_labels ,
765+ plddt_labels = plddt_labels ,
766+ resolved_labels = resolved_labels ,
767+ return_loss_breakdown = True ,
768+ return_loss = False , # force sampling even if labels are given
769+ return_confidence_head_logits = True
770+ )
771+
772+ assert sampled_atom_pos .ndim == 3
773+
685774 loss , _ = alphafold3 (
686775 num_recycling_steps = 2 ,
687776 atom_inputs = atom_inputs ,
@@ -733,6 +822,7 @@ def test_alphafold3_with_atom_and_bond_embeddings():
733822
734823 atom_pos = torch .randn (2 , atom_seq_len , 3 )
735824 distogram_atom_indices = molecule_atom_lens - 1 # last atom, as an example
825+ molecule_atom_indices = molecule_atom_lens - 1
736826
737827 distance_labels = torch .randint (0 , 37 , (2 , seq_len , seq_len ))
738828 pae_labels = torch .randint (0 , 64 , (2 , seq_len , seq_len ))
@@ -759,6 +849,7 @@ def test_alphafold3_with_atom_and_bond_embeddings():
759849 template_mask = template_mask ,
760850 atom_pos = atom_pos ,
761851 distogram_atom_indices = distogram_atom_indices ,
852+ molecule_atom_indices = molecule_atom_indices ,
762853 distance_labels = distance_labels ,
763854 pae_labels = pae_labels ,
764855 pde_labels = pde_labels ,
0 commit comments