Skip to content

Commit 8abd14d

Browse files
committed
add test for protein pretraining class
1 parent a71b199 commit 8abd14d

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import unittest
2+
from unittest.mock import PropertyMock, mock_open, patch
3+
from chebai.preprocessing.datasets.protein_pretraining import _ProteinPretrainingData
4+
from chebai.preprocessing.reader import ProteinDataReader
5+
from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData
6+
7+
8+
class TestProteinPretrainingData(unittest.TestCase):
    """
    Unit tests for the _ProteinPretrainingData class.

    A single shared instance is built once in ``setUpClass`` and then used to
    check the output of ``_parse_protein_data_for_pretraining`` against the
    mock UniProt data supplied by ``GOUniProtMockData``.
    """

    @classmethod
    @patch.multiple(_ProteinPretrainingData, __abstractmethods__=frozenset())
    @patch.object(_ProteinPretrainingData, "base_dir", new_callable=PropertyMock)
    @patch.object(_ProteinPretrainingData, "_name", new_callable=PropertyMock)
    @patch("os.makedirs", return_value=None)
    def setUpClass(
        cls,
        makedirs_mock,
        name_prop: PropertyMock,
        base_dir_prop: PropertyMock,
    ) -> None:
        """
        Build the shared ``_ProteinPretrainingData`` instance used by the tests.

        The decorators (active only while this method runs) clear the class's
        abstract-method set so it can be instantiated, replace the ``_name``
        and ``base_dir`` properties with ``PropertyMock`` objects, and stub out
        ``os.makedirs``. Mocks are injected bottom-up, i.e. the ``os.makedirs``
        mock comes first, then ``_name``, then ``base_dir``.
        """
        # Fixed values returned by the mocked properties.
        name_prop.return_value = "MockedNameProp_ProteinPretrainingData"
        base_dir_prop.return_value = "MockedBaseDirPropProteinPretrainingData"

        # The class under test needs a concrete READER before instantiation.
        _ProteinPretrainingData.READER = ProteinDataReader

        # Instance shared by every test method of this class.
        cls.extractor = _ProteinPretrainingData()

    @patch(
        "builtins.open",
        new_callable=mock_open,
        read_data=GOUniProtMockData.get_UniProt_raw_data(),
    )
    def test_parse_protein_data_for_pretraining(self, mock_open_file: mock_open) -> None:
        """
        Check the DataFrame returned by ``_parse_protein_data_for_pretraining``.

        ``builtins.open`` is patched so that any file read yields the mock
        UniProt raw data. The test then verifies that:
            - the set of ``swiss_id`` values equals the expected pretraining IDs;
            - every entry in the ``sequence`` column has a length greater than 0.
        """
        # Run the parser first, then fetch the reference IDs from the mock data.
        parsed_df = self.extractor._parse_protein_data_for_pretraining()
        expected_ids = GOUniProtMockData.proteins_for_pretraining()

        # Order and duplicates are irrelevant, so compare as sets.
        actual_id_set = set(parsed_df["swiss_id"])
        expected_id_set = set(expected_ids)
        self.assertEqual(
            actual_id_set,
            expected_id_set,
            msg="The parsed DataFrame does not contain the expected Swiss-Prot IDs for pretraining.",
        )

        # Every sequence must contain at least one character.
        sequence_lengths = parsed_df["sequence"].str.len()
        all_non_empty = sequence_lengths.gt(0).all()
        self.assertTrue(
            all_non_empty,
            msg="Some protein sequences in the pretraining DataFrame are empty.",
        )
68+
69+
70+
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)