@@ -36,12 +36,12 @@ alphafold3 = Alphafold3(
3636# mock inputs
3737
3838seq_len = 16
39- atom_seq_len = seq_len * 27
39+ residue_atom_lens = torch.randint(1 , 3 , (2 , seq_len))
40+ atom_seq_len = residue_atom_lens.sum(dim = - 1 ).amax()
4041
4142atom_inputs = torch.randn(2 , atom_seq_len, 77 )
4243atompair_inputs = torch.randn(2 , atom_seq_len, atom_seq_len, 5 )
4344
44- atom_lens = torch.randint(0 , 27 , (2 , seq_len))
4545additional_residue_feats = torch.randn(2 , seq_len, 10 )
4646
4747template_feats = torch.randn(2 , 2 , seq_len, seq_len, 44 )
@@ -53,7 +53,7 @@ msa_mask = torch.ones((2, 7)).bool()
5353# required for training, but omitted on inference
5454
5555atom_pos = torch.randn(2 , atom_seq_len, 3 )
56- residue_atom_indices = torch.randint( 0 , 27 , ( 2 , seq_len))
56+ residue_atom_indices = residue_atom_lens - 1 # last atom, as an example
5757
5858distance_labels = torch.randint(0 , 37 , (2 , seq_len, seq_len))
5959pae_labels = torch.randint(0 , 64 , (2 , seq_len, seq_len))
@@ -67,7 +67,7 @@ loss = alphafold3(
6767 num_recycling_steps = 2 ,
6868 atom_inputs = atom_inputs,
6969 atompair_inputs = atompair_inputs,
70- residue_atom_lens = atom_lens ,
70+ residue_atom_lens = residue_atom_lens ,
7171 additional_residue_feats = additional_residue_feats,
7272 msa = msa,
7373 msa_mask = msa_mask,
@@ -91,15 +91,15 @@ sampled_atom_pos = alphafold3(
9191 num_sample_steps = 16 ,
9292 atom_inputs = atom_inputs,
9393 atompair_inputs = atompair_inputs,
94- residue_atom_lens = atom_lens ,
94+ residue_atom_lens = residue_atom_lens ,
9595 additional_residue_feats = additional_residue_feats,
9696 msa = msa,
9797 msa_mask = msa_mask,
9898 templates = template_feats,
9999 template_mask = template_mask
100100)
101101
102- sampled_atom_pos.shape # (2, 16 * 27 , 3)
102+ sampled_atom_pos.shape # (2, <atom_seqlen> , 3)
103103```
104104
105105## Contributing
0 commit comments