@@ -26,7 +26,6 @@ from alphafold3_pytorch import Alphafold3
2626
2727alphafold3 = Alphafold3(
2828 dim_atom_inputs = 77 ,
29- dim_additional_residue_feats = 33 ,
3029 dim_template_feats = 44
3130)
3231
@@ -38,12 +37,13 @@ atom_seq_len = seq_len * 27
3837atom_inputs = torch.randn(2 , atom_seq_len, 77 )
3938atom_lens = torch.randint(0 , 27 , (2 , seq_len))
4039atompair_feats = torch.randn(2 , atom_seq_len, atom_seq_len, 16 )
41- additional_residue_feats = torch.randn(2 , seq_len, 33 )
40+ additional_residue_feats = torch.randn(2 , seq_len, 10 )
4241
4342template_feats = torch.randn(2 , 2 , seq_len, seq_len, 44 )
4443template_mask = torch.ones((2 , 2 )).bool()
4544
4645msa = torch.randn(2 , 7 , seq_len, 64 )
46+ msa_mask = torch.ones((2 , 7 )).bool()
4747
4848# required for training, but omitted on inference
4949
@@ -65,6 +65,7 @@ loss = alphafold3(
6565 atompair_feats = atompair_feats,
6666 additional_residue_feats = additional_residue_feats,
6767 msa = msa,
68+ msa_mask = msa_mask,
6869 templates = template_feats,
6970 template_mask = template_mask,
7071 atom_pos = atom_pos,
@@ -84,10 +85,11 @@ sampled_atom_pos = alphafold3(
8485 num_recycling_steps = 4 ,
8586 num_sample_steps = 16 ,
8687 atom_inputs = atom_inputs,
87- atom_mask = atom_mask ,
88+ residue_atom_lens = atom_lens ,
8889 atompair_feats = atompair_feats,
8990 additional_residue_feats = additional_residue_feats,
9091 msa = msa,
92+ msa_mask = msa_mask,
9193 templates = template_feats,
9294 template_mask = template_mask
9395)
0 commit comments