Skip to content

Commit 0056a19

Browse files
Merge pull request #249 from jeromekelleher/test_metadata
Test metadata
2 parents 62458c5 + d0c3a74 commit 0056a19

File tree

5 files changed

+148
-53
lines changed

5 files changed

+148
-53
lines changed

sc2ts/alignments.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,6 @@ def append(self, alignments, show_progress=False):
119119
self._flush(chunk)
120120
bar.close()
121121

122-
def __contains__(self, key):
123-
with self.env.begin() as txn:
124-
val = txn.get(key.encode())
125-
return val is not None
126-
127122
def __getitem__(self, key):
128123
with self.env.begin() as txn:
129124
val = txn.get(key.encode())

sc2ts/metadata.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import logging
23
import sqlite3
34
import pathlib
@@ -13,14 +14,27 @@ def dict_factory(cursor, row):
1314
return {key: value for key, value in zip(col_names, row)}
1415

1516

16-
class MetadataDb:
17+
class MetadataDb(collections.abc.Mapping):
1718
def __init__(self, path):
    """Open the SQLite database at ``path`` through a read-only URI.

    The URI (``file:<path>?mode=ro``) is kept on ``self.uri`` and rows
    fetched through ``self.conn`` are produced by ``dict_factory``.
    """
    self.uri = f"file:{path}" + "?mode=ro"
    connection = sqlite3.connect(self.uri, uri=True)
    connection.row_factory = dict_factory
    self.conn = connection
2324

25+
@staticmethod
26+
def import_csv(csv_path, db_path):
27+
df = pd.read_csv(csv_path, sep="\t")
28+
db_path = pathlib.Path(db_path)
29+
if db_path.exists():
30+
db_path.unlink()
31+
with sqlite3.connect(db_path) as conn:
32+
df.to_sql("samples", conn, index=False)
33+
conn.execute(
34+
"CREATE UNIQUE INDEX [ix_samples_strain] on 'samples' ([strain]);"
35+
)
36+
conn.execute("CREATE INDEX [ix_samples_date] on 'samples' ([date]);")
37+
2438
def __enter__(self):
2539
return self
2640

@@ -36,23 +50,23 @@ def __len__(self):
3650
row = self.conn.execute(sql).fetchone()
3751
return row["COUNT(*)"]
3852

53+
def __getitem__(self, key):
54+
sql = "SELECT * FROM samples WHERE strain==?"
55+
with self.conn:
56+
result = self.conn.execute(sql, [key]).fetchone()
57+
if result is None:
58+
raise KeyError(f"strain {key} not in DB")
59+
return result
60+
61+
def __iter__(self):
62+
sql = "SELECT strain FROM samples"
63+
with self.conn:
64+
for result in self.conn.execute(sql):
65+
yield result["strain"]
66+
3967
def close(self):
4068
self.conn.close()
4169

42-
@staticmethod
43-
def import_csv(csv_path, db_path):
44-
df = pd.read_csv(
45-
csv_path,
46-
sep="\t",
47-
)
48-
db_path = pathlib.Path(db_path)
49-
if db_path.exists():
50-
db_path.unlink()
51-
with sqlite3.connect(db_path) as conn:
52-
df.to_sql("samples", conn, index=False)
53-
conn.execute("CREATE INDEX [ix_samples_strain] on 'samples' ([strain]);")
54-
conn.execute("CREATE INDEX [ix_samples_date] on 'samples' ([date]);")
55-
5670
def get(self, date):
5771
sql = "SELECT * FROM samples WHERE date==?"
5872
with self.conn:

tests/conftest.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import pathlib
2+
import shutil
3+
import gzip
4+
5+
import pytest
6+
7+
import sc2ts
8+
9+
10+
@pytest.fixture
def data_cache():
    """Return the directory used to cache derived test data, creating it if needed."""
    cache_path = pathlib.Path("tests/data/cache")
    # exist_ok=True replaces the separate exists() check, which could race
    # with another process creating the directory between check and mkdir.
    # Parents are deliberately not created: a missing tests/data means the
    # tests were started from the wrong working directory and should fail.
    cache_path.mkdir(exist_ok=True)
    return cache_path
16+
17+
18+
@pytest.fixture
def alignments_fasta(data_cache):
    """Return the path of the decompressed alignments FASTA in the cache.

    The file is unpacked from ``tests/data/alignments.fasta.gz`` the first
    time it is requested and reused afterwards.
    """
    cache_path = data_cache / "alignments.fasta"
    if not cache_path.exists():
        # Decompress under a temporary name and rename once complete, so an
        # interrupted run cannot leave a truncated file that later runs
        # would accept as a valid cache entry.
        tmp_path = cache_path.with_name(cache_path.name + ".tmp")
        with gzip.open("tests/data/alignments.fasta.gz") as src:
            with open(tmp_path, "wb") as dest:
                shutil.copyfileobj(src, dest)
        tmp_path.replace(cache_path)
    return cache_path
26+
27+
28+
@pytest.fixture
def alignments_store(data_cache, alignments_fasta):
    """Yield an ``AlignmentStore`` over the cached alignments database.

    The database is built from the FASTA fixture the first time it is
    requested. The store handed to the test is opened in a ``with`` block
    so that its context-manager exit runs when the test finishes, rather
    than leaving a new store open after every test.
    """
    cache_path = data_cache / "alignments.db"
    if not cache_path.exists():
        with sc2ts.AlignmentStore(cache_path, "a") as a:
            fasta = sc2ts.core.FastaReader(alignments_fasta)
            a.append(fasta, show_progress=False)
    with sc2ts.AlignmentStore(cache_path) as store:
        yield store
36+
37+
@pytest.fixture
def metadata_db(data_cache):
    """Yield a ``MetadataDb`` over the cached metadata database.

    The database is imported from ``tests/data/metadata.tsv`` the first time
    it is requested. The connection is closed once the requesting test has
    finished instead of being left open.
    """
    cache_path = data_cache / "metadata.db"
    tsv_path = "tests/data/metadata.tsv"
    if not cache_path.exists():
        sc2ts.MetadataDb.import_csv(tsv_path, cache_path)
    db = sc2ts.MetadataDb(cache_path)
    try:
        yield db
    finally:
        db.close()

tests/test_alignments.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
import pathlib
2-
import shutil
3-
import gzip
4-
51
import numpy as np
62
import pytest
73
from numpy.testing import assert_array_equal
@@ -10,34 +6,6 @@
106
from sc2ts import core
117

128

13-
@pytest.fixture
14-
def data_cache():
15-
cache_path = pathlib.Path("tests/data/cache")
16-
if not cache_path.exists():
17-
cache_path.mkdir()
18-
return cache_path
19-
20-
21-
@pytest.fixture
22-
def alignments_fasta(data_cache):
23-
cache_path = data_cache / "alignments.fasta"
24-
if not cache_path.exists():
25-
with gzip.open("tests/data/alignments.fasta.gz") as src:
26-
with open(cache_path, "wb") as dest:
27-
shutil.copyfileobj(src, dest)
28-
return cache_path
29-
30-
31-
@pytest.fixture
32-
def alignments_store(data_cache, alignments_fasta):
33-
cache_path = data_cache / "alignments.db"
34-
if not cache_path.exists():
35-
with sa.AlignmentStore(cache_path, "a") as a:
36-
fasta = core.FastaReader(alignments_fasta)
37-
a.append(fasta, show_progress=False)
38-
return sa.AlignmentStore(cache_path)
39-
40-
419
class TestAlignmentsStore:
4210
def test_info(self, alignments_store):
4311
assert "contains" in str(alignments_store)
@@ -117,7 +85,7 @@ def test_lowercase_nucleotide_missing(self, hap):
11785
[0, -2],
11886
],
11987
)
120-
def test_examples(self, a):
88+
def test_error__examples(self, a):
12189
with pytest.raises(ValueError):
12290
sa.decode_alignment(np.array(a))
12391

tests/test_metadata.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import pytest
2+
import pandas as pd
3+
4+
5+
class TestMetadataDb:
    """Checks of the mapping interface and date queries on the metadata DB."""

    def test_known(self, metadata_db):
        row = metadata_db["SRR11772659"]
        expected_fields = {
            "strain": "SRR11772659",
            "date": "2020-01-19",
            "Viridian_pangolin": "A",
        }
        for column, value in expected_fields.items():
            assert row[column] == value

    def test_missing_sequence(self, metadata_db):
        # The metadata includes a sequence that is not in the alignments DB.
        assert "ERR_MISSING" in metadata_db

    def test_keys(self, metadata_db):
        strains = list(metadata_db.keys())
        assert "SRR11772659" in strains
        # No strain may be listed twice.
        assert len(strains) == len(set(strains))
        table = pd.read_csv("tests/data/metadata.tsv", sep="\t")
        assert set(table["strain"]) == set(strains)

    def test_in(self, metadata_db):
        assert "SRR11772659" in metadata_db
        assert "DEFO_NOT_IN_DB" not in metadata_db

    def test_get_all_days(self, metadata_db):
        expected = [
            "2020-01-01",
            "2020-01-19",
            "2020-01-24",
            "2020-01-25",
            "2020-01-28",
            "2020-01-29",
            "2020-01-30",
            "2020-01-31",
            "2020-02-01",
            "2020-02-02",
            "2020-02-03",
            "2020-02-04",
            "2020-02-05",
            "2020-02-06",
            "2020-02-07",
            "2020-02-08",
            "2020-02-09",
            "2020-02-10",
            "2020-02-11",
            "2020-02-13",
        ]
        assert metadata_db.get_days() == expected

    def test_get_days_greater(self, metadata_db):
        expected = [
            "2020-02-07",
            "2020-02-08",
            "2020-02-09",
            "2020-02-10",
            "2020-02-11",
            "2020-02-13",
        ]
        assert metadata_db.get_days("2020-02-06") == expected

    def test_get_days_none(self, metadata_db):
        assert metadata_db.get_days("2022-02-06") == []

    def test_get_first(self, metadata_db):
        rows = list(metadata_db.get("2020-01-01"))
        assert len(rows) == 1
        assert rows[0] == metadata_db["SRR14631544"]

    def test_get_multi(self, metadata_db):
        rows = list(metadata_db.get("2020-02-11"))
        assert len(rows) == 2
        assert all(row["date"] == "2020-02-11" for row in rows)

0 commit comments

Comments
 (0)