Commit 71fa9fe

update test for protein reader for tokenindexer changes

1 parent d653f52 · commit 71fa9fe

1 file changed: +65 −10 lines

tests/unit/readers/testProteinDataReader.py

Lines changed: 65 additions & 10 deletions
@@ -2,7 +2,9 @@
 from typing import List
 from unittest.mock import mock_open, patch

-from chebai_proteins.preprocessing.reader import EMBEDDING_OFFSET, ProteinDataReader
+from chebai.preprocessing.reader import EMBEDDING_OFFSET
+
+from chebai_proteins.preprocessing.reader import ProteinDataReader


 class TestProteinDataReader(unittest.TestCase):
@@ -25,14 +27,16 @@ def setUpClass(cls, mock_file: mock_open) -> None:
         """
         cls.reader = ProteinDataReader(token_path="/mock/path")
         # After initializing, cls.reader.cache should now be set to ['M', 'K', 'T', 'F', 'R', 'N']
-        assert cls.reader.cache == [
-            "M",
-            "K",
-            "T",
-            "F",
-            "R",
-            "N",
-        ], "Cache initialization did not match expected tokens."
+        assert list(cls.reader.cache.items()) == list(
+            {
+                "M": 0,
+                "K": 1,
+                "T": 2,
+                "F": 3,
+                "R": 4,
+                "N": 5,
+            }.items()
+        ), "Initial cache does not match expected values or the order doesn't match."

     def test_read_data(self) -> None:
         """
@@ -86,7 +90,7 @@ def test_read_data_with_new_token(self) -> None:
         )
         # Ensure it's at the correct index
         self.assertEqual(
-            self.reader.cache.index("Y"),
+            self.reader.cache["Y"],
             len(self.reader.cache) - 1,
             "The new token 'Y' was not added at the correct index in the cache.",
         )
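
With the token-indexer change, the reader's cache is a dict mapping each token to its index rather than a list of tokens, so position lookups become plain key lookups. A hypothetical before/after, just to illustrate the access-pattern change (reader and token are illustrative names):

    # before: cache was a list, index recovered by linear scan
    index = reader.cache.index("Y")  # O(n)
    # after: cache is an insertion-ordered dict of token -> index
    index = reader.cache["Y"]        # O(1)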
@@ -134,6 +138,57 @@ def test_read_data_with_repeated_tokens(self) -> None:
             "The _read_data method did not correctly handle repeated tokens.",
         )

+    @patch("builtins.open", new_callable=mock_open)
+    def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None:
+        """
+        Test the on_finish method to ensure it appends only the new tokens to the token file in order.
+        """
+        # Simulate that some tokens were already loaded
+        self.reader._loaded_tokens_count = 6  # 6 tokens already loaded
+        self.reader.cache = {
+            "M": 0,
+            "K": 1,
+            "T": 2,
+            "F": 3,
+            "R": 4,
+            "N": 5,
+            "W": 6,  # New token 1
+            "Y": 7,  # New token 2
+            "V": 8,  # New token 3
+            "Q": 9,  # New token 4
+            "E": 10,  # New token 5
+        }
+
+        # Run the on_finish method
+        self.reader.on_finish()
+
+        # Check that the file was opened in append mode ('a')
+        mock_file.assert_called_with(self.reader.token_path, "a")
+
+        # Verify the new tokens were written in the correct order
+        mock_file().writelines.assert_called_with(
+            ["W\n", "Y\n", "V\n", "Q\n", "E\n"]
+        )
+
+    def test_finish_method_no_new_tokens(self) -> None:
+        """
+        Test the on_finish method when no new tokens are added (cache is the same).
+        """
+        self.reader._loaded_tokens_count = 6  # No new tokens
+        self.reader.cache = {
+            "M": 0,
+            "K": 1,
+            "T": 2,
+            "F": 3,
+            "R": 4,
+            "N": 5,
+        }
+
+        with patch("builtins.open", new_callable=mock_open) as mock_file:
+            self.reader.on_finish()
+            # Check that no new tokens were written
+            mock_file().writelines.assert_not_called()
+

 if __name__ == "__main__":
     unittest.main()
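
The two new tests pin down the on_finish contract that comes with the token-indexer change: the cache is an insertion-ordered dict of token → index, _loaded_tokens_count records how many tokens were already in the token file, and only tokens added after loading are appended, in order. A minimal sketch of that contract as a standalone class — the name, layout, and body here are assumptions for illustration, not chebai's actual implementation:

    # Hypothetical sketch of the on_finish behaviour the tests above assert;
    # the real method lives in chebai's reader hierarchy.
    class TokenCacheSketch:
        def __init__(self, token_path: str) -> None:
            self.token_path = token_path
            self.cache: dict[str, int] = {}  # token -> index, insertion-ordered
            self._loaded_tokens_count = 0    # tokens already present on disk

        def on_finish(self) -> None:
            # Dict keys preserve insertion order, so slicing past the loaded
            # count yields exactly the tokens added during this run, in order.
            new_tokens = list(self.cache)[self._loaded_tokens_count:]
            if not new_tokens:
                return  # nothing new: the file is never written
            # Append mode leaves existing tokens (and their indices) intact.
            with open(self.token_path, "a") as f:
                f.writelines([f"{token}\n" for token in new_tokens])

Appending instead of rewriting keeps previously assigned indices stable across runs, which matters because encoded sequences store token indices (shifted by EMBEDDING_OFFSET) rather than the tokens themselves.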
