Skip to content

Commit 8536d4e

Browse files
authored
Add ValueError when AptaTransPipeline receives a too small depth value (#220)
This PR resolves #196. Added a `raise ValueError(...)` when the `AptaTransPipeline` is initialized with `depth` smaller than 3. Therefore, changes are limited to `aptatrans.pipeline.AptaTransPipeline` and its corresponding tests, where I added a new test to check whether the exception is correctly raised. See the issue updated description for more info about why the bug occurred. I believe the best option is to raise an exception because the triplet (3-mers) encoding used by AptaTrans will simply produce empty vectors when provides sequences with length less than 3. This encoding is used in the original representation.
1 parent 7e3bfc8 commit 8536d4e

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

examples/aptatrans_tutorial.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@
548548
"name": "python",
549549
"nbconvert_exporter": "python",
550550
"pygments_lexer": "ipython3",
551-
"version": "3.12.9"
551+
"version": "3.11.13"
552552
}
553553
},
554554
"nbformat": 4,

pyaptamer/aptatrans/_pipeline.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ class AptaTransPipeline:
4343
subsequences and their frequency should come from the same dataset used for
4444
pretraining the protein encoder.
4545
depth : int, optional, default=20
46-
The depth of the tree in the Monte Carlo Tree Search (MCTS) algorithm.
46+
The depth of the tree in the Monte Carlo Tree Search (MCTS) algorithm. Also
47+
defines the length of the generated aptamer candidates. Must be equal or
48+
greater than 3 since preprocessing uses triplet encoding (3-mers), which
49+
requires sequences of at least 3 nucleotides to extract overlapping triplets.
4750
n_iterations : int, optional, default=1000
4851
The number of iterations for the MCTS algorithm.
4952
@@ -91,7 +94,17 @@ def __init__(
9194
depth: int = 20,
9295
n_iterations: int = 1000,
9396
) -> None:
94-
super().__init__()
97+
"""
98+
Raises
99+
------
100+
ValueError
101+
If `depth` is less than 3.
102+
"""
103+
if depth < 3:
104+
raise ValueError(
105+
f"Invalid depth value: {depth}. Must be grater or equal than 3."
106+
)
107+
95108
self.device = device
96109
self.model = model.to(device)
97110
self.depth = depth

pyaptamer/aptatrans/tests/test_aptatrans.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,20 @@ def test_initialization(self, device, prot_words):
230230
expected_prot_count = sum(1 for freq in prot_words.values() if freq > mean_freq)
231231
assert len(pipeline.prot_words) == expected_prot_count
232232

233+
@pytest.mark.parametrize("depth", [-1, 0, 1, 2])
234+
def test_initialization_with_small_depth(self, depth):
235+
"""Check ValueError is raised at initialization when depth is less than 3."""
236+
model = MockAptaTransNeuralNet(torch.device("cpu"))
237+
prot_words = {"AAA": 0.5, "AAC": 0.3, "AAG": 0.8}
238+
239+
with pytest.raises(ValueError):
240+
AptaTransPipeline(
241+
device=torch.device("cpu"),
242+
model=model,
243+
prot_words=prot_words,
244+
depth=depth,
245+
)
246+
233247
@pytest.mark.parametrize(
234248
"device, target",
235249
[

0 commit comments

Comments
 (0)