Skip to content

Commit e8435e5

Browse files
Add systematic matching tests against fixtures
1 parent 6c54c0f commit e8435e5

File tree

2 files changed

+120
-90
lines changed

2 files changed

+120
-90
lines changed

sc2ts/inference.py

Lines changed: 50 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -433,73 +433,6 @@ def asdict(self):
433433
# last_ts = ts
434434

435435

436-
def preprocess(
437-
date,
438-
*,
439-
base_ts,
440-
metadata_db,
441-
alignment_store,
442-
max_daily_samples=None,
443-
show_progress=False,
444-
):
445-
samples = []
446-
metadata_matches = list(metadata_db.get(date))
447-
448-
if len(metadata_matches) == 0:
449-
logger.warn(f"Zero metadata matches for {date}")
450-
return []
451-
452-
if date.endswith("12-31"):
453-
logger.warning(f"Skipping {len(metadata_matches)} samples for {date}")
454-
return []
455-
456-
# TODO implement this.
457-
assert max_daily_samples is None
458-
459-
keep_sites = base_ts.sites_position.astype(int)
460-
problematic_sites = core.get_problematic_sites()
461-
samples = []
462-
463-
with tqdm.tqdm(
464-
metadata_matches,
465-
desc=f"Preprocess:{date}",
466-
disable=not show_progress,
467-
) as bar:
468-
for md in bar:
469-
strain = md["strain"]
470-
try:
471-
alignment = alignment_store[strain]
472-
except KeyError:
473-
logger.debug(f"No alignment stored for {strain}")
474-
continue
475-
476-
sample = Sample(strain, date, metadata=md)
477-
ma = alignments.encode_and_mask(alignment)
478-
# Always mask the problematic_sites as well. We need to do this
479-
# for follow-up matching to inspect recombinants, as tsinfer
480-
# needs us to keep all sites in the table when doing mirrored
481-
# coordinates.
482-
ma.alignment[problematic_sites] = -1
483-
sample.alignment_qc = ma.qc_summary()
484-
sample.masked_sites = ma.masked_sites
485-
sample.alignment = ma.alignment[keep_sites]
486-
samples.append(sample)
487-
num_Ns = ma.original_base_composition.get("N", 0)
488-
non_nuc_counts = dict(ma.original_base_composition)
489-
for nuc in "ACGT":
490-
del non_nuc_counts[nuc]
491-
counts = ",".join(
492-
f"{key}={count}" for key, count in sorted(non_nuc_counts.items())
493-
)
494-
num_masked = len(ma.masked_sites)
495-
logger.debug(f"Mask {strain}: masked={num_masked} {counts}")
496-
497-
logger.info(
498-
f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
499-
)
500-
return samples
501-
502-
503436
def match_samples(
504437
date,
505438
samples,
@@ -563,6 +496,47 @@ def check_base_ts(ts):
563496
assert len(sc2ts_md["samples_strain"]) == ts.num_samples
564497

565498

499+
def preprocess(samples_md, base_ts, date, alignment_store, show_progress=False):
500+
keep_sites = base_ts.sites_position.astype(int)
501+
problematic_sites = core.get_problematic_sites()
502+
503+
samples = []
504+
with tqdm.tqdm(
505+
samples_md,
506+
desc=f"Preprocess",
507+
disable=not show_progress,
508+
) as bar:
509+
for md in bar:
510+
strain = md["strain"]
511+
try:
512+
alignment = alignment_store[strain]
513+
except KeyError:
514+
logger.debug(f"No alignment stored for {strain}")
515+
continue
516+
sample = Sample(strain, date, metadata=md)
517+
ma = alignments.encode_and_mask(alignment)
518+
# Always mask the problematic_sites as well. We need to do this
519+
# for follow-up matching to inspect recombinants, as tsinfer
520+
# needs us to keep all sites in the table when doing mirrored
521+
# coordinates.
522+
ma.alignment[problematic_sites] = -1
523+
sample.alignment_qc = ma.qc_summary()
524+
sample.masked_sites = ma.masked_sites
525+
sample.alignment = ma.alignment[keep_sites]
526+
samples.append(sample)
527+
num_Ns = ma.original_base_composition.get("N", 0)
528+
non_nuc_counts = dict(ma.original_base_composition)
529+
for nuc in "ACGT":
530+
del non_nuc_counts[nuc]
531+
counts = ",".join(
532+
f"{key}={count}" for key, count in sorted(non_nuc_counts.items())
533+
)
534+
num_masked = len(ma.masked_sites)
535+
logger.debug(f"Mask {strain}: masked={num_masked} {counts}")
536+
537+
return samples
538+
539+
566540
def extend(
567541
*,
568542
alignment_store,
@@ -594,19 +568,22 @@ def extend(
594568
f"mutations={base_ts.num_mutations};date={base_ts.metadata['sc2ts']['date']}"
595569
)
596570

571+
metadata_matches = list(metadata_db.get(date))
572+
# TODO implement this.
573+
assert max_daily_samples is None
574+
597575
samples = preprocess(
598-
date,
599-
metadata_db=metadata_db,
600-
alignment_store=alignment_store,
601-
base_ts=base_ts,
602-
max_daily_samples=max_daily_samples,
603-
show_progress=show_progress,
576+
metadata_matches, base_ts, date, alignment_store, show_progress=show_progress
604577
)
605578

606579
if len(samples) == 0:
607580
logger.warning(f"Nothing to do for {date}")
608581
return base_ts
609582

583+
logger.info(
584+
f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
585+
)
586+
610587
match_samples(
611588
date,
612589
samples,

tests/test_inference.py

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -546,21 +546,74 @@ def test_2020_02_10_metadata(self, fx_ts_2020_02_10):
546546

547547

548548
class TestMatchingDetails:
549-
550-
def test_exact_matches(self, fx_ts_2020_02_10, fx_alignment_store, fx_metadata_db):
551-
print("HERE")
552-
553-
def test_other_exact_matches(self, tmp_path, fx_ts_2020_02_10, fx_alignment_store, fx_metadata_db):
554-
print("HERE")
555-
match_db = sc2ts.MatchDb.initialise(tmp_path / "match.db")
556-
ts = sc2ts.extend(
557-
alignment_store=fx_alignment_store,
558-
metadata_db=fx_metadata_db,
559-
base_ts=fx_ts_2020_02_10,
560-
date="2020-02-11",
561-
match_db=match_db,
562-
min_group_size=2,
549+
@pytest.mark.parametrize(
550+
("strain", "parent"), [("SRR11597207", 42), ("ERR4205570", 62)]
551+
)
552+
@pytest.mark.parametrize("num_mismatches", [1, 2, 3, 4])
553+
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
554+
def test_exact_matches(
555+
self,
556+
fx_ts_2020_02_10,
557+
fx_alignment_store,
558+
fx_metadata_db,
559+
strain,
560+
parent,
561+
num_mismatches,
562+
precision,
563+
):
564+
samples = sc2ts.preprocess(
565+
[fx_metadata_db[strain]], fx_ts_2020_02_10, "2020-02-20", fx_alignment_store
563566
)
564-
565-
566-
567+
sc2ts.match_tsinfer(
568+
samples=samples,
569+
ts=fx_ts_2020_02_10,
570+
num_mismatches=num_mismatches,
571+
precision=precision,
572+
num_threads=0,
573+
)
574+
s = samples[0]
575+
assert len(s.mutations) == 0
576+
assert len(s.path) == 1
577+
assert s.path[0].parent == parent
578+
579+
# def test_stuff(
580+
# self, tmp_path, fx_ts_2020_02_10, fx_alignment_store, fx_metadata_db
581+
# ):
582+
# # SRR11597207 0 42 0
583+
# # SRR11597218 1 10 1
584+
585+
# # date = "2020-02-11" # 2 samples
586+
# date = "2020-02-13" # 4 samples
587+
# samples = sc2ts.preprocess(
588+
# date,
589+
# metadata_db=fx_metadata_db,
590+
# alignment_store=fx_alignment_store,
591+
# base_ts=fx_ts_2020_02_10,
592+
# )
593+
# # print(samples)
594+
595+
# num_mismatches = 3
596+
# sc2ts.match_tsinfer(
597+
# samples=samples,
598+
# ts=fx_ts_2020_02_10,
599+
# num_mismatches=3,
600+
# precision=12,
601+
# num_threads=0,
602+
# )
603+
# for sample in samples:
604+
# print(
605+
# sample.strain,
606+
# sample.get_hmm_cost(num_mismatches),
607+
# sample.path[0].parent,
608+
# len(sample.mutations),
609+
# )
610+
611+
# # match_db = sc2ts.MatchDb.initialise(tmp_path / "match.db")
612+
# # ts = sc2ts.extend(
613+
# # alignment_store=fx_alignment_store,
614+
# # metadata_db=fx_metadata_db,
615+
# # base_ts=fx_ts_2020_02_10,
616+
# # date="2020-02-11",
617+
# # match_db=match_db,
618+
# # min_group_size=2,
619+
# # )

0 commit comments

Comments
 (0)