Skip to content

Commit e996bfd

Browse files
Merge pull request #255 from jeromekelleher/add-some-matching-tests
Add some matching tests
2 parents 4a35bec + bad0019 commit e996bfd

File tree

6 files changed

+261
-156
lines changed

6 files changed

+261
-156
lines changed

sc2ts/inference.py

Lines changed: 50 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -433,73 +433,6 @@ def asdict(self):
433433
# last_ts = ts
434434

435435

436-
def preprocess(
437-
date,
438-
*,
439-
base_ts,
440-
metadata_db,
441-
alignment_store,
442-
max_daily_samples=None,
443-
show_progress=False,
444-
):
445-
samples = []
446-
metadata_matches = list(metadata_db.get(date))
447-
448-
if len(metadata_matches) == 0:
449-
logger.warn(f"Zero metadata matches for {date}")
450-
return []
451-
452-
if date.endswith("12-31"):
453-
logger.warning(f"Skipping {len(metadata_matches)} samples for {date}")
454-
return []
455-
456-
# TODO implement this.
457-
assert max_daily_samples is None
458-
459-
keep_sites = base_ts.sites_position.astype(int)
460-
problematic_sites = core.get_problematic_sites()
461-
samples = []
462-
463-
with tqdm.tqdm(
464-
metadata_matches,
465-
desc=f"Preprocess:{date}",
466-
disable=not show_progress,
467-
) as bar:
468-
for md in bar:
469-
strain = md["strain"]
470-
try:
471-
alignment = alignment_store[strain]
472-
except KeyError:
473-
logger.debug(f"No alignment stored for {strain}")
474-
continue
475-
476-
sample = Sample(strain, date, metadata=md)
477-
ma = alignments.encode_and_mask(alignment)
478-
# Always mask the problematic_sites as well. We need to do this
479-
# for follow-up matching to inspect recombinants, as tsinfer
480-
# needs us to keep all sites in the table when doing mirrored
481-
# coordinates.
482-
ma.alignment[problematic_sites] = -1
483-
sample.alignment_qc = ma.qc_summary()
484-
sample.masked_sites = ma.masked_sites
485-
sample.alignment = ma.alignment[keep_sites]
486-
samples.append(sample)
487-
num_Ns = ma.original_base_composition.get("N", 0)
488-
non_nuc_counts = dict(ma.original_base_composition)
489-
for nuc in "ACGT":
490-
del non_nuc_counts[nuc]
491-
counts = ",".join(
492-
f"{key}={count}" for key, count in sorted(non_nuc_counts.items())
493-
)
494-
num_masked = len(ma.masked_sites)
495-
logger.debug(f"Mask {strain}: masked={num_masked} {counts}")
496-
497-
logger.info(
498-
f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
499-
)
500-
return samples
501-
502-
503436
def match_samples(
504437
date,
505438
samples,
@@ -563,6 +496,47 @@ def check_base_ts(ts):
563496
assert len(sc2ts_md["samples_strain"]) == ts.num_samples
564497

565498

499+
def preprocess(samples_md, base_ts, date, alignment_store, show_progress=False):
500+
keep_sites = base_ts.sites_position.astype(int)
501+
problematic_sites = core.get_problematic_sites()
502+
503+
samples = []
504+
with tqdm.tqdm(
505+
samples_md,
506+
desc=f"Preprocess",
507+
disable=not show_progress,
508+
) as bar:
509+
for md in bar:
510+
strain = md["strain"]
511+
try:
512+
alignment = alignment_store[strain]
513+
except KeyError:
514+
logger.debug(f"No alignment stored for {strain}")
515+
continue
516+
sample = Sample(strain, date, metadata=md)
517+
ma = alignments.encode_and_mask(alignment)
518+
# Always mask the problematic_sites as well. We need to do this
519+
# for follow-up matching to inspect recombinants, as tsinfer
520+
# needs us to keep all sites in the table when doing mirrored
521+
# coordinates.
522+
ma.alignment[problematic_sites] = -1
523+
sample.alignment_qc = ma.qc_summary()
524+
sample.masked_sites = ma.masked_sites
525+
sample.alignment = ma.alignment[keep_sites]
526+
samples.append(sample)
527+
num_Ns = ma.original_base_composition.get("N", 0)
528+
non_nuc_counts = dict(ma.original_base_composition)
529+
for nuc in "ACGT":
530+
del non_nuc_counts[nuc]
531+
counts = ",".join(
532+
f"{key}={count}" for key, count in sorted(non_nuc_counts.items())
533+
)
534+
num_masked = len(ma.masked_sites)
535+
logger.debug(f"Mask {strain}: masked={num_masked} {counts}")
536+
537+
return samples
538+
539+
566540
def extend(
567541
*,
568542
alignment_store,
@@ -594,19 +568,22 @@ def extend(
594568
f"mutations={base_ts.num_mutations};date={base_ts.metadata['sc2ts']['date']}"
595569
)
596570

571+
metadata_matches = list(metadata_db.get(date))
572+
# TODO implement this.
573+
assert max_daily_samples is None
574+
597575
samples = preprocess(
598-
date,
599-
metadata_db=metadata_db,
600-
alignment_store=alignment_store,
601-
base_ts=base_ts,
602-
max_daily_samples=max_daily_samples,
603-
show_progress=show_progress,
576+
metadata_matches, base_ts, date, alignment_store, show_progress=show_progress
604577
)
605578

606579
if len(samples) == 0:
607580
logger.warning(f"Nothing to do for {date}")
608581
return base_ts
609582

583+
logger.info(
584+
f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
585+
)
586+
610587
match_samples(
611588
date,
612589
samples,

tests/conftest.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
import pathlib
22
import shutil
33
import gzip
4+
import tskit
45

56
import pytest
67

78
import sc2ts
89

910

1011
@pytest.fixture
11-
def data_cache():
12+
def fx_data_cache():
1213
cache_path = pathlib.Path("tests/data/cache")
1314
if not cache_path.exists():
1415
cache_path.mkdir()
1516
return cache_path
1617

1718

1819
@pytest.fixture
19-
def alignments_fasta(data_cache):
20-
cache_path = data_cache / "alignments.fasta"
20+
def fx_alignments_fasta(fx_data_cache):
21+
cache_path = fx_data_cache / "alignments.fasta"
2122
if not cache_path.exists():
2223
with gzip.open("tests/data/alignments.fasta.gz") as src:
2324
with open(cache_path, "wb") as dest:
@@ -26,18 +27,43 @@ def alignments_fasta(data_cache):
2627

2728

2829
@pytest.fixture
29-
def alignments_store(data_cache, alignments_fasta):
30-
cache_path = data_cache / "alignments.db"
30+
def fx_alignment_store(fx_data_cache, fx_alignments_fasta):
31+
cache_path = fx_data_cache / "alignments.db"
3132
if not cache_path.exists():
3233
with sc2ts.AlignmentStore(cache_path, "a") as a:
33-
fasta = sc2ts.core.FastaReader(alignments_fasta)
34+
fasta = sc2ts.core.FastaReader(fx_alignments_fasta)
3435
a.append(fasta, show_progress=False)
3536
return sc2ts.AlignmentStore(cache_path)
3637

3738
@pytest.fixture
38-
def metadata_db(data_cache):
39-
cache_path = data_cache / "metadata.db"
39+
def fx_metadata_db(fx_data_cache):
40+
cache_path = fx_data_cache / "metadata.db"
4041
tsv_path = "tests/data/metadata.tsv"
4142
if not cache_path.exists():
4243
sc2ts.MetadataDb.import_csv(tsv_path, cache_path)
4344
return sc2ts.MetadataDb(cache_path)
45+
46+
47+
@pytest.fixture
48+
def fx_ts_2020_02_10(tmp_path, fx_data_cache, fx_metadata_db, fx_alignment_store):
49+
target_date = "2020-02-10"
50+
cache_path = fx_data_cache / f"{target_date}.ts"
51+
if not cache_path.exists():
52+
last_ts = sc2ts.initial_ts()
53+
match_db = sc2ts.MatchDb.initialise(tmp_path / "match.db")
54+
for date in fx_metadata_db.date_sample_counts():
55+
print("INFERRING", date)
56+
last_ts = sc2ts.extend(
57+
alignment_store=fx_alignment_store,
58+
metadata_db=fx_metadata_db,
59+
base_ts=last_ts,
60+
date=date,
61+
match_db=match_db,
62+
min_group_size=2,
63+
)
64+
if date == target_date:
65+
break
66+
last_ts.dump(cache_path)
67+
return tskit.load(cache_path)
68+
69+

tests/test_alignments.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,27 @@
77

88

99
class TestAlignmentsStore:
10-
def test_info(self, alignments_store):
11-
assert "contains" in str(alignments_store)
10+
def test_info(self, fx_alignment_store):
11+
assert "contains" in str(fx_alignment_store)
1212

13-
def test_len(self, alignments_store):
14-
assert len(alignments_store) == 55
13+
def test_len(self, fx_alignment_store):
14+
assert len(fx_alignment_store) == 55
1515

16-
def test_fetch_known(self, alignments_store):
17-
a = alignments_store["SRR11772659"]
16+
def test_fetch_known(self, fx_alignment_store):
17+
a = fx_alignment_store["SRR11772659"]
1818
assert a.shape == (core.REFERENCE_SEQUENCE_LENGTH,)
1919
assert a[0] == "X"
2020
assert a[1] == "N"
2121
assert a[-1] == "N"
2222

23-
def test_keys(self, alignments_store):
24-
keys = list(alignments_store.keys())
25-
assert len(keys) == len(alignments_store)
23+
def test_keys(self, fx_alignment_store):
24+
keys = list(fx_alignment_store.keys())
25+
assert len(keys) == len(fx_alignment_store)
2626
assert "SRR11772659" in keys
2727

28-
def test_in(self, alignments_store):
29-
assert "SRR11772659" in alignments_store
30-
assert "NOT_IN_STORE" not in alignments_store
28+
def test_in(self, fx_alignment_store):
29+
assert "SRR11772659" in fx_alignment_store
30+
assert "NOT_IN_STORE" not in fx_alignment_store
3131

3232

3333
def test_get_gene_coordinates():
@@ -89,8 +89,8 @@ def test_error__examples(self, a):
8989
with pytest.raises(ValueError):
9090
sa.decode_alignment(np.array(a))
9191

92-
def test_encode_real(self, alignments_store):
93-
h = alignments_store["SRR11772659"]
92+
def test_encode_real(self, fx_alignment_store):
93+
h = fx_alignment_store["SRR11772659"]
9494
a = sa.encode_alignment(h)
9595
assert a[0] == -1
9696
assert a[-1] == -1
@@ -145,8 +145,8 @@ def test_bad_window_size(self, w):
145145

146146

147147
class TestEncodeAndMask:
148-
def test_known(self, alignments_store):
149-
a = alignments_store["SRR11772659"]
148+
def test_known(self, fx_alignment_store):
149+
a = fx_alignment_store["SRR11772659"]
150150
ma = sa.encode_and_mask(a)
151151
assert ma.original_base_composition == {
152152
"T": 9566,

tests/test_cli.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ def test_additional_problematic_sites(self, tmp_path, additional):
4747

4848

4949
class TestListDates:
50-
def test_defaults(self, metadata_db):
50+
def test_defaults(self, fx_metadata_db):
5151
runner = ct.CliRunner(mix_stderr=False)
5252
result = runner.invoke(
5353
cli.cli,
54-
f"list-dates {metadata_db.path}",
54+
f"list-dates {fx_metadata_db.path}",
5555
catch_exceptions=False,
5656
)
5757
assert result.exit_code == 0
@@ -78,11 +78,11 @@ def test_defaults(self, metadata_db):
7878
"2020-02-13",
7979
]
8080

81-
def test_counts(self, metadata_db):
81+
def test_counts(self, fx_metadata_db):
8282
runner = ct.CliRunner(mix_stderr=False)
8383
result = runner.invoke(
8484
cli.cli,
85-
f"list-dates {metadata_db.path} --counts",
85+
f"list-dates {fx_metadata_db.path} --counts",
8686
catch_exceptions=False,
8787
)
8888
assert result.exit_code == 0

0 commit comments

Comments
 (0)