|
37 | 37 | atom_ref_pos_to_atompair_inputs |
38 | 38 | ) |
39 | 39 |
|
| 40 | +from alphafold3_pytorch.inputs import ( |
| 41 | + IS_MOLECULE_TYPES |
| 42 | +) |
| 43 | + |
40 | 44 | def test_atom_ref_pos_to_atompair_inputs(): |
41 | 45 | atom_ref_pos = torch.randn(16, 3) |
42 | 46 | atom_ref_space_uid = torch.ones(16).long() |
@@ -444,7 +448,7 @@ def test_alphafold3( |
444 | 448 |
|
445 | 449 | additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5)) |
446 | 450 | additional_token_feats = torch.randn(2, 16, 2) |
447 | | - is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool() |
| 451 | + is_molecule_types = torch.randint(0, 2, (2, seq_len, IS_MOLECULE_TYPES)).bool() |
448 | 452 | molecule_ids = torch.randint(0, 32, (2, seq_len)) |
449 | 453 |
|
450 | 454 | is_molecule_mod = None |
@@ -556,7 +560,7 @@ def test_alphafold3_without_msa_and_templates(): |
556 | 560 | atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5) |
557 | 561 | additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5)) |
558 | 562 | additional_token_feats = torch.randn(2, seq_len, 2) |
559 | | - is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool() |
| 563 | + is_molecule_types = torch.randint(0, 2, (2, seq_len, IS_MOLECULE_TYPES)).bool() |
560 | 564 | molecule_ids = torch.randint(0, 32, (2, seq_len)) |
561 | 565 |
|
562 | 566 | atom_pos = torch.randn(2, atom_seq_len, 3) |
@@ -621,7 +625,7 @@ def test_alphafold3_force_return_loss(): |
621 | 625 | atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5) |
622 | 626 | additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5)) |
623 | 627 | additional_token_feats = torch.randn(2, seq_len, 2) |
624 | | - is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool() |
| 628 | + is_molecule_types = torch.randint(0, 2, (2, seq_len, IS_MOLECULE_TYPES)).bool() |
625 | 629 | molecule_ids = torch.randint(0, 32, (2, seq_len)) |
626 | 630 |
|
627 | 631 | atom_pos = torch.randn(2, atom_seq_len, 3) |
@@ -716,7 +720,7 @@ def test_alphafold3_with_atom_and_bond_embeddings(): |
716 | 720 |
|
717 | 721 | additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5)) |
718 | 722 | additional_token_feats = torch.randn(2, seq_len, 2) |
719 | | - is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool() |
| 723 | + is_molecule_types = torch.randint(0, 2, (2, seq_len, IS_MOLECULE_TYPES)).bool() |
720 | 724 | molecule_ids = torch.randint(0, 32, (2, seq_len)) |
721 | 725 |
|
722 | 726 | template_feats = torch.randn(2, 2, seq_len, seq_len, 44) |
|
0 commit comments