@@ -77,6 +77,7 @@ atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
7777additional_molecule_feats = torch.randint(0 , 2 , (2 , seq_len, 5 ))
7878additional_token_feats = torch.randn(2 , seq_len, 2 )
7979is_molecule_types = torch.randint(0 , 2 , (2 , seq_len, 5 )).bool()
80+ is_molecule_mod = torch.randint(0 , 2 , (2 , seq_len, 4 )).bool()
8081molecule_ids = torch.randint(0 , 32 , (2 , seq_len))
8182
8283template_feats = torch.randn(2 , 2 , seq_len, seq_len, 44 )
@@ -107,6 +108,7 @@ loss = alphafold3(
107108 additional_molecule_feats = additional_molecule_feats,
108109 additional_token_feats = additional_token_feats,
109110 is_molecule_types = is_molecule_types,
111+ is_molecule_mod = is_molecule_mod,
110112 msa = msa,
111113 msa_mask = msa_mask,
112114 templates = template_feats,
@@ -134,6 +136,7 @@ sampled_atom_pos = alphafold3(
134136 additional_molecule_feats = additional_molecule_feats,
135137 additional_token_feats = additional_token_feats,
136138 is_molecule_types = is_molecule_types,
139+ is_molecule_mod = is_molecule_mod,
137140 msa = msa,
138141 msa_mask = msa_mask,
139142 templates = template_feats,
@@ -180,6 +183,7 @@ alphafold3 = Alphafold3(
180183 atoms_per_window = 27 ,
181184 dim_template_feats = 44 ,
182185 num_dist_bins = 38 ,
186+ num_molecule_mods = 0 ,
183187 confidence_head_kwargs = dict (
184188 pairformer_depth = 1
185189 ),
0 commit comments