Skip to content

Commit b93c800

Browse files
Reasonable tests for alignments module
1 parent a862c48 commit b93c800

File tree

2 files changed

+80
-31
lines changed

2 files changed

+80
-31
lines changed

sc2ts/alignments.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class AlignmentStore(collections.abc.Mapping):
8080
def __init__(self, path, mode="r"):
8181
map_size = 1024**4
8282
self.env = lmdb.Environment(
83-
path, subdir=False, readonly=mode == "r", map_size=map_size
83+
str(path), subdir=False, readonly=mode == "r", map_size=map_size
8484
)
8585

8686
def __enter__(self):
@@ -95,21 +95,6 @@ def close(self):
9595
def __str__(self):
    """
    Return a one-line summary giving the store's path and alignment count.
    """
    location = self.env.path()
    num_alignments = len(self)
    return f"AlignmentStore at {location} contains {num_alignments} alignments"
9797

98-
@staticmethod
99-
def initialise(path):
100-
"""
101-
Create a new store at this path.
102-
"""
103-
db_path = pathlib.Path(path)
104-
if db_path.exists():
105-
db_path.unlink()
106-
107-
reference = core.get_reference_sequence()
108-
with lmdb.Environment(str(db_path), subdir=False) as env:
109-
with env.begin(write=True) as txn:
110-
txn.put("MN908947".encode(), compress_alignment(reference))
111-
return AlignmentStore(path, "a")
112-
11398
def _flush(self, chunk):
11499
logger.debug(f"Flushing {len(chunk)} sequences")
115100
with self.env.begin(write=True) as txn:
@@ -157,20 +142,6 @@ def __len__(self):
157142
with self.env.begin() as txn:
158143
return txn.stat()["entries"]
159144

160-
def get_all(self, strains, sequence_length):
161-
A = np.zeros((len(strains), sequence_length), dtype=np.int8)
162-
with self.env.begin() as txn:
163-
for j, strain in enumerate(strains):
164-
val = txn.get(strain.encode())
165-
if val is None:
166-
raise KeyError(f"{strain} not found")
167-
a = decompress_alignment(val)
168-
if len(a) != sequence_length:
169-
raise ValueError(
170-
f"Alignment for {strain} not of length {sequence_length}"
171-
)
172-
return A
173-
174145

175146
@dataclasses.dataclass
176147
class MaskedAlignment:

tests/test_alignments.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,67 @@
1+
import pathlib
2+
import shutil
3+
import gzip
4+
15
import numpy as np
26
import pytest
37
from numpy.testing import assert_array_equal
48

5-
# FIXME - quick hacks here to get tests working
69
from sc2ts import alignments as sa
710
from sc2ts import core
811

912

13+
@pytest.fixture
def data_cache():
    """
    Return the directory used to cache derived test data, creating it
    if it does not exist yet.

    The directory is ``tests/data/cache``, resolved relative to the
    current working directory. Its parent must already exist.
    """
    cache_path = pathlib.Path("tests/data/cache")
    # exist_ok=True replaces a separate exists() check, which could race
    # with another process creating the directory between check and mkdir.
    cache_path.mkdir(exist_ok=True)
    return cache_path
19+
20+
21+
@pytest.fixture
def alignments_fasta(data_cache):
    """
    Return the path of an uncompressed copy of the test alignments FASTA.

    The copy lives in the cache directory and is produced by decompressing
    ``tests/data/alignments.fasta.gz`` the first time it is requested;
    later requests reuse the existing file.
    """
    fasta_path = data_cache / "alignments.fasta"
    if fasta_path.exists():
        return fasta_path

    with gzip.open("tests/data/alignments.fasta.gz") as compressed, open(
        fasta_path, "wb"
    ) as uncompressed:
        shutil.copyfileobj(compressed, uncompressed)
    return fasta_path
29+
30+
31+
@pytest.fixture
def alignments_store(data_cache, alignments_fasta):
    """
    Yield an AlignmentStore opened in the default ("r") mode over the
    cached test alignments.

    The backing database is built from the FASTA fixture the first time
    it is needed and reused afterwards. The store handed to the test is
    closed again once the test has finished.
    """
    cache_path = data_cache / "alignments.db"
    if not cache_path.exists():
        with sa.AlignmentStore(cache_path, "a") as a:
            fasta = core.FastaReader(alignments_fasta)
            a.append(fasta, show_progress=False)
    store = sa.AlignmentStore(cache_path)
    try:
        yield store
    finally:
        # Release the underlying environment so that open handles do not
        # accumulate across tests.
        # NOTE(review): LMDB documents opening the same file more than once
        # in a process as unsafe - confirm close() fully releases it.
        store.close()
39+
40+
41+
class TestAlignmentsStore:
    """
    Exercise the mapping-style interface of the AlignmentStore built from
    the cached test alignments.
    """

    # Strain that the assertions below expect to find in the test data.
    KNOWN_STRAIN = "SRR11772659"

    def test_info(self, alignments_store):
        summary = str(alignments_store)
        assert "contains" in summary

    def test_len(self, alignments_store):
        num_alignments = len(alignments_store)
        assert num_alignments == 55

    def test_fetch_known(self, alignments_store):
        alignment = alignments_store[self.KNOWN_STRAIN]
        expected_shape = (core.REFERENCE_SEQUENCE_LENGTH,)
        assert alignment.shape == expected_shape
        assert alignment[0] == "X"
        assert alignment[1] == "N"
        assert alignment[-1] == "N"

    def test_keys(self, alignments_store):
        strains = list(alignments_store.keys())
        assert len(strains) == len(alignments_store)
        assert self.KNOWN_STRAIN in strains

    def test_in(self, alignments_store):
        assert self.KNOWN_STRAIN in alignments_store
        assert "NOT_IN_STORE" not in alignments_store
63+
64+
1065
def test_get_gene_coordinates():
1166
d = core.get_gene_coordinates()
1267
assert len(d) == 11
@@ -66,6 +121,12 @@ def test_examples(self, a):
66121
with pytest.raises(ValueError):
67122
sa.decode_alignment(np.array(a))
68123

124+
def test_encode_real(self, alignments_store):
    """
    Encode a real alignment from the store and check that its first and
    last entries both come out as -1.
    """
    haplotype = alignments_store["SRR11772659"]
    encoded = sa.encode_alignment(haplotype)
    for index in (0, -1):
        assert encoded[index] == -1
129+
69130

70131
class TestMasking:
71132
# Window size of 1 is weird because we have to have two or more
@@ -113,3 +174,20 @@ def test_bad_window_size(self, w):
113174
a = np.zeros(2, dtype=np.int8)
114175
with pytest.raises(ValueError):
115176
sa.mask_alignment(a, window_size=w)
177+
178+
179+
class TestEncodeAndMask:
    """
    Regression check of encode_and_mask against a known alignment from
    the test data.
    """

    def test_known(self, alignments_store):
        alignment = alignments_store["SRR11772659"]
        masked = sa.encode_and_mask(alignment)

        expected_composition = {
            "T": 9566,
            "A": 8894,
            "G": 5850,
            "C": 5472,
            "N": 121,
        }
        assert masked.original_base_composition == expected_composition
        assert masked.original_md5 == "e96feaa72c4f4baba73c2e147ede7502"

        sites = masked.masked_sites
        assert len(sites) == 133
        assert sites[0] == 1
        assert sites[-1] == 29903

0 commit comments

Comments
 (0)