3636 ConfidenceHeadLogits ,
3737 ComputeModelSelectionScore ,
3838 ComputeModelSelectionScore ,
39- collate_inputs_to_batched_atom_input
39+ collate_inputs_to_batched_atom_input ,
40+ alphafold3_inputs_to_batched_atom_input ,
4041)
4142
4243from alphafold3_pytorch .mocks import MockAtomDataset
6162 molecule_to_atom_input ,
6263 pdb_input_to_molecule_input ,
6364 PDBInput ,
65+ Alphafold3Input ,
6466 PDBDataset ,
6567 default_extract_atom_feats_fn ,
6668 default_extract_atompair_feats_fn ,
@@ -1226,3 +1228,149 @@ def test_unresolved_protein_rasa():
12261228 molecule_atom_lens = batched_atom_input_dict ['molecule_atom_lens' ],
12271229 atom_pos = batched_atom_input_dict ['atom_pos' ],
12281230 atom_mask = ~ batched_atom_input_dict ['missing_atom_mask' ])
1231+
1232+ def test_readme1 ():
1233+ alphafold3 = Alphafold3 (
1234+ dim_atom_inputs = 77 ,
1235+ dim_template_feats = 44
1236+ )
1237+
1238+ # mock inputs
1239+
1240+ seq_len = 16
1241+ molecule_atom_lens = torch .randint (1 , 3 , (2 , seq_len ))
1242+ atom_seq_len = molecule_atom_lens .sum (dim = - 1 ).amax ()
1243+
1244+ atom_inputs = torch .randn (2 , atom_seq_len , 77 )
1245+ atompair_inputs = torch .randn (2 , atom_seq_len , atom_seq_len , 5 )
1246+
1247+ additional_molecule_feats = torch .randint (0 , 2 , (2 , seq_len , 5 ))
1248+ additional_token_feats = torch .randn (2 , seq_len , 33 )
1249+ is_molecule_types = torch .randint (0 , 2 , (2 , seq_len , 5 )).bool ()
1250+ is_molecule_mod = torch .randint (0 , 2 , (2 , seq_len , 4 )).bool ()
1251+ molecule_ids = torch .randint (0 , 32 , (2 , seq_len ))
1252+
1253+ template_feats = torch .randn (2 , 2 , seq_len , seq_len , 44 )
1254+ template_mask = torch .ones ((2 , 2 )).bool ()
1255+
1256+ msa = torch .randn (2 , 7 , seq_len , 32 )
1257+ msa_mask = torch .ones ((2 , 7 )).bool ()
1258+
1259+ additional_msa_feats = torch .randn (2 , 7 , seq_len , 2 )
1260+
1261+ # required for training, but omitted on inference
1262+
1263+ atom_pos = torch .randn (2 , atom_seq_len , 3 )
1264+
1265+ molecule_atom_indices = molecule_atom_lens - 1 # last atom, as an example
1266+ molecule_atom_indices += (molecule_atom_lens .cumsum (dim = - 1 ) - molecule_atom_lens )
1267+
1268+ distance_labels = torch .randint (0 , 37 , (2 , seq_len , seq_len ))
1269+ resolved_labels = torch .randint (0 , 2 , (2 , atom_seq_len ))
1270+
1271+ # train
1272+
1273+ loss = alphafold3 (
1274+ num_recycling_steps = 2 ,
1275+ atom_inputs = atom_inputs ,
1276+ atompair_inputs = atompair_inputs ,
1277+ molecule_ids = molecule_ids ,
1278+ molecule_atom_lens = molecule_atom_lens ,
1279+ additional_molecule_feats = additional_molecule_feats ,
1280+ additional_msa_feats = additional_msa_feats ,
1281+ additional_token_feats = additional_token_feats ,
1282+ is_molecule_types = is_molecule_types ,
1283+ is_molecule_mod = is_molecule_mod ,
1284+ msa = msa ,
1285+ msa_mask = msa_mask ,
1286+ templates = template_feats ,
1287+ template_mask = template_mask ,
1288+ atom_pos = atom_pos ,
1289+ molecule_atom_indices = molecule_atom_indices ,
1290+ distance_labels = distance_labels ,
1291+ resolved_labels = resolved_labels
1292+ )
1293+
1294+ loss .backward ()
1295+
1296+ # after much training ...
1297+
1298+ sampled_atom_pos = alphafold3 (
1299+ num_recycling_steps = 4 ,
1300+ num_sample_steps = 16 ,
1301+ atom_inputs = atom_inputs ,
1302+ atompair_inputs = atompair_inputs ,
1303+ molecule_ids = molecule_ids ,
1304+ molecule_atom_lens = molecule_atom_lens ,
1305+ additional_molecule_feats = additional_molecule_feats ,
1306+ additional_msa_feats = additional_msa_feats ,
1307+ additional_token_feats = additional_token_feats ,
1308+ is_molecule_types = is_molecule_types ,
1309+ is_molecule_mod = is_molecule_mod ,
1310+ msa = msa ,
1311+ msa_mask = msa_mask ,
1312+ templates = template_feats ,
1313+ template_mask = template_mask
1314+ )
1315+
1316+ sampled_atom_pos .shape # (2, <atom_seqlen>, 3)
1317+ assert sampled_atom_pos .ndim == 3
1318+
1319+ def test_readme2 ():
1320+ contrived_protein = 'AG'
1321+
1322+ mock_atompos = [
1323+ torch .randn (5 , 3 ), # alanine has 5 non-hydrogen atoms
1324+ torch .randn (4 , 3 ) # glycine has 4 non-hydrogen atoms
1325+ ]
1326+
1327+ train_alphafold3_input = Alphafold3Input (
1328+ proteins = [contrived_protein ],
1329+ atom_pos = mock_atompos
1330+ )
1331+
1332+ eval_alphafold3_input = Alphafold3Input (
1333+ proteins = [contrived_protein ]
1334+ )
1335+
1336+ batched_atom_input = alphafold3_inputs_to_batched_atom_input (train_alphafold3_input , atoms_per_window = 27 )
1337+
1338+ # training
1339+
1340+ alphafold3 = Alphafold3 (
1341+ dim_atom_inputs = 3 ,
1342+ dim_atompair_inputs = 5 ,
1343+ atoms_per_window = 27 ,
1344+ dim_template_feats = 44 ,
1345+ num_dist_bins = 38 ,
1346+ num_molecule_mods = 0 ,
1347+ confidence_head_kwargs = dict (
1348+ pairformer_depth = 1
1349+ ),
1350+ template_embedder_kwargs = dict (
1351+ pairformer_stack_depth = 1
1352+ ),
1353+ msa_module_kwargs = dict (
1354+ depth = 1
1355+ ),
1356+ pairformer_stack = dict (
1357+ depth = 2
1358+ ),
1359+ diffusion_module_kwargs = dict (
1360+ atom_encoder_depth = 1 ,
1361+ token_transformer_depth = 1 ,
1362+ atom_decoder_depth = 1 ,
1363+ )
1364+ )
1365+
1366+ loss = alphafold3 (** batched_atom_input .model_forward_dict ())
1367+ loss .backward ()
1368+
1369+ # sampling
1370+
1371+ batched_eval_atom_input = alphafold3_inputs_to_batched_atom_input (eval_alphafold3_input , atoms_per_window = 27 )
1372+
1373+ alphafold3 .eval ()
1374+ sampled_atom_pos = alphafold3 (** batched_eval_atom_input .model_forward_dict ())
1375+
1376+ assert sampled_atom_pos .shape == (1 , (5 + 4 ), 3 )
0 commit comments