@@ -38,13 +38,13 @@ alphafold3 = Alphafold3(
3838# mock inputs
3939
4040seq_len = 16
41- residue_atom_lens = torch.randint(1 , 3 , (2 , seq_len))
42- atom_seq_len = residue_atom_lens .sum(dim = - 1 ).amax()
41+ molecule_atom_lens = torch.randint(1 , 3 , (2 , seq_len))
42+ atom_seq_len = molecule_atom_lens .sum(dim = - 1 ).amax()
4343
4444atom_inputs = torch.randn(2 , atom_seq_len, 77 )
4545atompair_inputs = torch.randn(2 , atom_seq_len, atom_seq_len, 5 )
4646
47- additional_residue_feats = torch.randn(2 , seq_len, 10 )
47+ additional_molecule_feats = torch.randn(2 , seq_len, 10 )
4848
4949template_feats = torch.randn(2 , 2 , seq_len, seq_len, 44 )
5050template_mask = torch.ones((2 , 2 )).bool()
@@ -55,7 +55,7 @@ msa_mask = torch.ones((2, 7)).bool()
5555# required for training, but omitted on inference
5656
5757atom_pos = torch.randn(2 , atom_seq_len, 3 )
58- residue_atom_indices = residue_atom_lens - 1 # last atom, as an example
58+ molecule_atom_indices = molecule_atom_lens - 1 # last atom, as an example
5959
6060distance_labels = torch.randint(0 , 37 , (2 , seq_len, seq_len))
6161pae_labels = torch.randint(0 , 64 , (2 , seq_len, seq_len))
@@ -69,14 +69,14 @@ loss = alphafold3(
6969 num_recycling_steps = 2 ,
7070 atom_inputs = atom_inputs,
7171 atompair_inputs = atompair_inputs,
72- residue_atom_lens = residue_atom_lens ,
73- additional_residue_feats = additional_residue_feats ,
72+ molecule_atom_lens = molecule_atom_lens ,
73+ additional_molecule_feats = additional_molecule_feats ,
7474 msa = msa,
7575 msa_mask = msa_mask,
7676 templates = template_feats,
7777 template_mask = template_mask,
7878 atom_pos = atom_pos,
79- residue_atom_indices = residue_atom_indices ,
79+ molecule_atom_indices = molecule_atom_indices ,
8080 distance_labels = distance_labels,
8181 pae_labels = pae_labels,
8282 pde_labels = pde_labels,
@@ -93,8 +93,8 @@ sampled_atom_pos = alphafold3(
9393 num_sample_steps = 16 ,
9494 atom_inputs = atom_inputs,
9595 atompair_inputs = atompair_inputs,
96- residue_atom_lens = residue_atom_lens ,
97- additional_residue_feats = additional_residue_feats ,
96+ molecule_atom_lens = molecule_atom_lens ,
97+ additional_molecule_feats = additional_molecule_feats ,
9898 msa = msa,
9999 msa_mask = msa_mask,
100100 templates = template_feats,
0 commit comments