Skip to content

Commit c7970e0

Browse files
Fixup tests
realising there are multiple correct answers here
1 parent 86b8008 commit c7970e0

File tree

2 files changed

+34
-37
lines changed

2 files changed

+34
-37
lines changed

sc2ts/inference.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,7 @@ def add(self, samples, date, num_mismatches):
102102
pkl_compressed,
103103
)
104104
data.append(args)
105-
pango = sample.metadata.get("Viridian_pangolin", "Unknown")
106-
logger.debug(
107-
f"MatchDB insert: {sample.strain} {date} {pango} hmm_cost={hmm_cost[j]}"
108-
)
105+
logger.debug(f"MatchDB insert: hmm_cost={hmm_cost[j]} {sample.summary()}")
109106
# Batch insert, for efficiency.
110107
with self.conn:
111108
self.conn.executemany(sql, data)
@@ -150,11 +147,7 @@ def get(self, where_clause):
150147
for row in self.conn.execute(sql):
151148
pkl = row.pop("pickle")
152149
sample = pickle.loads(bz2.decompress(pkl))
153-
pango = sample.metadata.get("Viridian_pangolin", "Unknown")
154-
logger.debug(
155-
f"MatchDb got: {sample.strain} {sample.date} {pango} "
156-
f"hmm_cost={row['hmm_cost']}"
157-
)
150+
logger.debug(f"MatchDb got: {sample.summary()} hmm_cost={row['hmm_cost']}")
158151
# print(row)
159152
yield sample
160153

@@ -364,6 +357,18 @@ class Sample:
364357
# def __str__(self):
365358
# return f"{self.strain}: {self.path} + {self.mutations}"
366359

360+
def path_summary(self):
361+
return ",".join(f"({seg.left}:{seg.right}, {seg.parent})" for seg in self.path)
362+
363+
def mutation_summary(self):
364+
return "[" + ",".join(str(mutation) for mutation in self.mutations) + "]"
365+
366+
def summary(self):
367+
pango = self.metadata.get("Viridian_pangolin", "Unknown")
368+
return (f"{self.strain} {self.date} {pango} path={self.path_summary()} "
369+
f"mutations({len(self.mutations)})={self.mutation_summary()}"
370+
)
371+
367372
@property
368373
def breakpoints(self):
369374
breakpoints = [seg.left for seg in self.path]
@@ -415,9 +420,7 @@ def match_samples(
415420
exceeding_threshold = []
416421
for sample in run_batch:
417422
cost = sample.get_hmm_cost(num_mismatches)
418-
logger.debug(
419-
f"HMM@p={precision}: {sample.strain} hmm_cost={cost} path={sample.path}"
420-
)
423+
logger.debug(f"HMM@p={precision}: hmm_cost={cost} {sample.summary()}")
421424
if cost > cost_threshold:
422425
sample.path.clear()
423426
sample.mutations.clear()
@@ -441,11 +444,9 @@ def match_samples(
441444
show_progress=show_progress,
442445
)
443446
for sample in run_batch:
444-
hmm_cost = sample.get_hmm_cost(num_mismatches)
447+
cost = sample.get_hmm_cost(num_mismatches)
445448
# print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}")
446-
logger.debug(
447-
f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
448-
)
449+
logger.debug(f"Final HMM pass hmm_cost={cost} {sample.summary()}")
449450
return samples
450451

451452

@@ -1439,7 +1440,7 @@ def get_closest_mutation(node, site_id):
14391440
sample.mutations.append(
14401441
MatchMutation(
14411442
site_id=site_id,
1442-
site_position=site_pos,
1443+
site_position=int(site_pos),
14431444
derived_state=derived_state,
14441445
inherited_state=inherited_state,
14451446
is_reversion=is_reversion,

tests/test_inference.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -571,20 +571,12 @@ def test_2020_02_02(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_d
571571
)
572572
assert ts.num_samples == 26
573573
assert np.sum(ts.nodes_time[ts.samples()] == 0) == 4
574-
samples = {}
575-
for u in ts.samples()[-4:]:
576-
node = ts.node(u)
577-
samples[node.metadata["strain"]] = node
578-
smd = node.metadata["sc2ts"]
579-
md = node.metadata
580-
print(md["date"], md["strain"], len(smd["mutations"]))
581574
# print(samples)
582575
# print(fx_ts_map["2020-02-01"])
583576
# print(ts)
584577
# print(fx_ts_map["2020-02-02"])
585578
ts.tables.assert_equals(fx_ts_map["2020-02-02"].tables, ignore_provenance=True)
586579

587-
588580
@pytest.mark.parametrize("date", dates)
589581
def test_date_metadata(self, fx_ts_map, date):
590582
ts = fx_ts_map[date]
@@ -601,7 +593,11 @@ def test_date_validate(self, fx_ts_map, fx_alignment_store, date):
601593

602594
@pytest.mark.parametrize("date", dates[1:])
603595
def test_node_mutation_counts(self, fx_ts_map, date):
604-
# Basic check to make sure our fixtures are what we expect
596+
# Basic check to make sure our fixtures are what we expect.
597+
# NOTE: this is somewhat fragile as the number of nodes does change
598+
# a little depending on the exact solution that the HMM chooses, for
599+
# example when there are multiple single-mutation matches at different
600+
# sites.
605601
ts = fx_ts_map[date]
606602
expected = {
607603
"2020-01-19": {"nodes": 3, "mutations": 3},
@@ -616,13 +612,13 @@ def test_node_mutation_counts(self, fx_ts_map, date):
616612
"2020-02-03": {"nodes": 36, "mutations": 42},
617613
"2020-02-04": {"nodes": 41, "mutations": 48},
618614
"2020-02-05": {"nodes": 42, "mutations": 48},
619-
"2020-02-06": {"nodes": 48, "mutations": 51},
620-
"2020-02-07": {"nodes": 50, "mutations": 57},
621-
"2020-02-08": {"nodes": 56, "mutations": 58},
622-
"2020-02-09": {"nodes": 58, "mutations": 61},
623-
"2020-02-10": {"nodes": 59, "mutations": 65},
624-
"2020-02-11": {"nodes": 61, "mutations": 66},
625-
"2020-02-13": {"nodes": 65, "mutations": 68},
615+
"2020-02-06": {"nodes": 49, "mutations": 51},
616+
"2020-02-07": {"nodes": 51, "mutations": 57},
617+
"2020-02-08": {"nodes": 57, "mutations": 58},
618+
"2020-02-09": {"nodes": 59, "mutations": 61},
619+
"2020-02-10": {"nodes": 60, "mutations": 65},
620+
"2020-02-11": {"nodes": 62, "mutations": 66},
621+
"2020-02-13": {"nodes": 66, "mutations": 68},
626622
}
627623
assert ts.num_nodes == expected[date]["nodes"]
628624
assert ts.num_mutations == expected[date]["mutations"]
@@ -635,9 +631,9 @@ def test_node_mutation_counts(self, fx_ts_map, date):
635631
(13, "SRR11597132", 10),
636632
(16, "SRR11597177", 10),
637633
(41, "SRR11597156", 10),
638-
(56, "SRR11597216", 1),
639-
(59, "SRR11597207", 40),
640-
(61, "ERR4205570", 57),
634+
(57, "SRR11597216", 1),
635+
(60, "SRR11597207", 40),
636+
(62, "ERR4205570", 58),
641637
],
642638
)
643639
def test_exact_matches(self, fx_ts_map, node, strain, parent):
@@ -697,7 +693,7 @@ class TestMatchingDetails:
697693
# assert s.path[0].parent == 37
698694

699695
@pytest.mark.parametrize(
700-
("strain", "parent"), [("SRR11597207", 41), ("ERR4205570", 58)]
696+
("strain", "parent"), [("SRR11597207", 40), ("ERR4205570", 58)]
701697
)
702698
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
703699
@pytest.mark.parametrize("precision", [0, 1, 2, 12])

0 commit comments

Comments
 (0)