Skip to content

Commit 673989d

Browse files
committed
fix tests
1 parent 6d34c69 commit 673989d

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

tests/test_af3.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
atom_ref_pos_to_atompair_inputs
3838
)
3939

40+
from alphafold3_pytorch.inputs import (
41+
IS_MOLECULE_TYPES
42+
)
43+
4044
def test_atom_ref_pos_to_atompair_inputs():
4145
atom_ref_pos = torch.randn(16, 3)
4246
atom_ref_space_uid = torch.ones(16).long()
@@ -444,7 +448,7 @@ def test_alphafold3(
444448

445449
additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
446450
additional_token_feats = torch.randn(2, 16, 2)
447-
is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool()
451+
is_molecule_types = torch.randint(0, 2, (2, seq_len, IS_MOLECULE_TYPES)).bool()
448452
molecule_ids = torch.randint(0, 32, (2, seq_len))
449453

450454
is_molecule_mod = None
@@ -556,7 +560,7 @@ def test_alphafold3_without_msa_and_templates():
556560
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
557561
additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
558562
additional_token_feats = torch.randn(2, seq_len, 2)
559-
is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool()
563+
is_molecule_types = torch.randint(0, 2, (2, seq_len, IS_MOLECULE_TYPES)).bool()
560564
molecule_ids = torch.randint(0, 32, (2, seq_len))
561565

562566
atom_pos = torch.randn(2, atom_seq_len, 3)
@@ -621,7 +625,7 @@ def test_alphafold3_force_return_loss():
621625
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
622626
additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
623627
additional_token_feats = torch.randn(2, seq_len, 2)
624-
is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool()
628+
is_molecule_types = torch.randint(0, 2, (2, seq_len, IS_MOLECULE_TYPES)).bool()
625629
molecule_ids = torch.randint(0, 32, (2, seq_len))
626630

627631
atom_pos = torch.randn(2, atom_seq_len, 3)
@@ -716,7 +720,7 @@ def test_alphafold3_with_atom_and_bond_embeddings():
716720

717721
additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
718722
additional_token_feats = torch.randn(2, seq_len, 2)
719-
is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool()
723+
is_molecule_types = torch.randint(0, 2, (2, seq_len, IS_MOLECULE_TYPES)).bool()
720724
molecule_ids = torch.randint(0, 32, (2, seq_len))
721725

722726
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)

tests/test_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
create_alphafold3_from_yaml
2222
)
2323

24+
from alphafold3_pytorch.inputs import (
25+
IS_MOLECULE_TYPES
26+
)
27+
2428
def exists(v):
2529
return v is not None
2630

@@ -50,7 +54,7 @@ def __getitem__(self, idx):
5054
molecule_atom_lens = torch.randint(1, self.atoms_per_window, (seq_len,))
5155
additional_molecule_feats = torch.randint(0, 2, (seq_len, 5))
5256
additional_token_feats = torch.randn(seq_len, 2)
53-
is_molecule_types = torch.randint(0, 2, (seq_len, 4)).bool()
57+
is_molecule_types = torch.randint(0, 2, (seq_len, IS_MOLECULE_TYPES)).bool()
5458
molecule_ids = torch.randint(0, 32, (seq_len,))
5559
token_bonds = torch.randint(0, 2, (seq_len, seq_len)).bool()
5660

0 commit comments

Comments
 (0)