|
37 | 37 | atom_ref_pos_to_atompair_inputs |
38 | 38 | ) |
39 | 39 |
|
40 | | -def join(str, delimiter = ','): |
41 | | - return delimiter.join(str) |
42 | | - |
43 | 40 | def test_atom_ref_pos_to_atompair_inputs(): |
44 | 41 | atom_ref_pos = torch.randn(16, 3) |
45 | 42 | atom_ref_space_uid = torch.ones(16).long() |
@@ -409,14 +406,8 @@ def test_distogram_head(): |
409 | 406 |
|
410 | 407 | logits = distogram_head(pairwise_repr) |
411 | 408 |
|
412 | | -@pytest.mark.parametrize( |
413 | | - join([ |
414 | | - 'window_atompair_inputs', |
415 | | - 'stochastic_frame_average' |
416 | | - ]), [ |
417 | | - (True, False), |
418 | | - (True, False) |
419 | | - ]) |
| 409 | +@pytest.mark.parametrize('window_atompair_inputs', (True, False)) |
| 410 | +@pytest.mark.parametrize('stochastic_frame_average', (True, False)) |
420 | 411 | def test_alphafold3( |
421 | 412 | window_atompair_inputs: bool, |
422 | 413 | stochastic_frame_average: bool |
@@ -572,6 +563,78 @@ def test_alphafold3_without_msa_and_templates(): |
572 | 563 |
|
573 | 564 | loss.backward() |
574 | 565 |
|
| 566 | +def test_alphafold3_force_return_loss(): |
| 567 | + seq_len = 16 |
| 568 | + molecule_atom_lens = torch.randint(1, 3, (2, seq_len)) |
| 569 | + atom_seq_len = molecule_atom_lens.sum(dim = -1).amax() |
| 570 | + |
| 571 | + atom_inputs = torch.randn(2, atom_seq_len, 77) |
| 572 | + atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5) |
| 573 | + additional_molecule_feats = torch.randn(2, seq_len, 10) |
| 574 | + |
| 575 | + atom_pos = torch.randn(2, atom_seq_len, 3) |
| 576 | + molecule_atom_indices = molecule_atom_lens - 1 |
| 577 | + |
| 578 | + distance_labels = torch.randint(0, 38, (2, seq_len, seq_len)) |
| 579 | + pae_labels = torch.randint(0, 64, (2, seq_len, seq_len)) |
| 580 | + pde_labels = torch.randint(0, 64, (2, seq_len, seq_len)) |
| 581 | + plddt_labels = torch.randint(0, 50, (2, seq_len)) |
| 582 | + resolved_labels = torch.randint(0, 2, (2, seq_len)) |
| 583 | + |
| 584 | + alphafold3 = Alphafold3( |
| 585 | + dim_atom_inputs = 77, |
| 586 | + dim_template_feats = 44, |
| 587 | + num_dist_bins = 38, |
| 588 | + confidence_head_kwargs = dict( |
| 589 | + pairformer_depth = 1 |
| 590 | + ), |
| 591 | + template_embedder_kwargs = dict( |
| 592 | + pairformer_stack_depth = 1 |
| 593 | + ), |
| 594 | + msa_module_kwargs = dict( |
| 595 | + depth = 1 |
| 596 | + ), |
| 597 | + pairformer_stack = dict( |
| 598 | + depth = 2 |
| 599 | + ), |
| 600 | + diffusion_module_kwargs = dict( |
| 601 | + atom_encoder_depth = 1, |
| 602 | + token_transformer_depth = 1, |
| 603 | + atom_decoder_depth = 1, |
| 604 | + ), |
| 605 | + ) |
| 606 | + |
| 607 | + sampled_atom_pos = alphafold3( |
| 608 | + num_recycling_steps = 2, |
| 609 | + atom_inputs = atom_inputs, |
| 610 | + molecule_atom_lens = molecule_atom_lens, |
| 611 | + atompair_inputs = atompair_inputs, |
| 612 | + additional_molecule_feats = additional_molecule_feats, |
| 613 | + atom_pos = atom_pos, |
| 614 | + molecule_atom_indices = molecule_atom_indices, |
| 615 | + distance_labels = distance_labels, |
| 616 | + pae_labels = pae_labels, |
| 617 | + pde_labels = pde_labels, |
| 618 | + plddt_labels = plddt_labels, |
| 619 | + resolved_labels = resolved_labels, |
| 620 | + return_loss_breakdown = True, |
| 621 | + return_loss = False # force sampling even if labels are given |
| 622 | + ) |
| 623 | + |
| 624 | + assert sampled_atom_pos.ndim == 3 |
| 625 | + |
| 626 | + loss, _ = alphafold3( |
| 627 | + num_recycling_steps = 2, |
| 628 | + atom_inputs = atom_inputs, |
| 629 | + molecule_atom_lens = molecule_atom_lens, |
| 630 | + atompair_inputs = atompair_inputs, |
| 631 | + additional_molecule_feats = additional_molecule_feats, |
| 632 | + return_loss_breakdown = True, |
| 633 | + return_loss = True # force returning loss even if no labels given |
| 634 | + ) |
| 635 | + |
| 636 | + assert loss == 0. |
| 637 | + |
575 | 638 | # test creation from config |
576 | 639 |
|
577 | 640 | def test_alphafold3_config(): |
|
0 commit comments