Skip to content

Commit ac8280e

Browse files
Intermediate update with identical result at full precision
1 parent 23ad2d3 commit ac8280e

File tree

2 files changed

+84
-51
lines changed

2 files changed

+84
-51
lines changed

sc2ts/inference.py

Lines changed: 58 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -392,22 +392,51 @@ def match_samples(
392392
date,
393393
samples,
394394
*,
395-
match_db,
396395
base_ts,
397396
num_mismatches=None,
398397
show_progress=False,
399398
num_threads=None,
400-
precision=None,
401-
mirror_coordinates=False,
402399
):
403-
if num_mismatches is None:
404-
# Default to no recombination
405-
num_mismatches = 1000
406-
407-
# FIXME Something wrong here, we don't seem to get precisely the same
408-
# ARG for some reason. Need to track it down
409-
# Also: should only run the things at low precision that have that HMM cost.
410-
# Start out by setting everything to have 0 mutations and work up from there.
400+
# First pass, compute the matches at precision=0.
401+
# precision = 0
402+
# match_tsinfer(
403+
# samples=samples,
404+
# ts=base_ts,
405+
# num_mismatches=num_mismatches,
406+
# precision=precision,
407+
# num_threads=num_threads,
408+
# show_progress=show_progress,
409+
# )
410+
411+
# cost_threshold = 1
412+
# rerun_batch = []
413+
# for sample in samples:
414+
# cost = sample.get_hmm_cost(num_mismatches)
415+
# logger.debug(
416+
# f"HMM@p={precision}: {sample.strain} hmm_cost={cost} path={sample.path}"
417+
# )
418+
# if cost > cost_threshold:
419+
# sample.path.clear()
420+
# sample.mutations.clear()
421+
# rerun_batch.append(sample)
422+
423+
rerun_batch = samples
424+
precision = 12
425+
logger.info(f"Rerunning batch of {len(rerun_batch)} at p={precision}")
426+
match_tsinfer(
427+
samples=rerun_batch,
428+
ts=base_ts,
429+
num_mismatches=num_mismatches,
430+
precision=12,
431+
num_threads=num_threads,
432+
show_progress=show_progress,
433+
)
434+
# for sample in samples_to_rerun:
435+
# hmm_cost = sample.get_hmm_cost(num_mismatches)
436+
# # print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}")
437+
# logger.debug(
438+
# f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
439+
# )
411440

412441
# remaining_samples = samples
413442
# for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]:
@@ -433,24 +462,8 @@ def match_samples(
433462
# samples_to_rerun.append(sample)
434463
# remaining_samples = samples_to_rerun
435464

436-
samples_to_rerun = samples
437-
match_tsinfer(
438-
samples=samples_to_rerun,
439-
ts=base_ts,
440-
num_mismatches=num_mismatches,
441-
precision=12,
442-
num_threads=num_threads,
443-
show_progress=show_progress,
444-
mirror_coordinates=mirror_coordinates,
445-
)
446-
for sample in samples_to_rerun:
447-
hmm_cost = sample.get_hmm_cost(num_mismatches)
448-
# print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}")
449-
logger.debug(
450-
f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
451-
)
452-
453-
match_db.add(samples, date, num_mismatches)
465+
# Return in sorted order so that results are deterministic
466+
return sorted(samples, key=lambda s: s.strain)
454467

455468

456469
def check_base_ts(ts):
@@ -526,7 +539,6 @@ def extend(
526539
min_group_size = 10
527540

528541
# TMP
529-
precision = 6
530542
check_base_ts(base_ts)
531543
logger.info(
532544
f"Extend {date}; ts:nodes={base_ts.num_nodes};samples={base_ts.num_samples};"
@@ -549,17 +561,16 @@ def extend(
549561
f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
550562
)
551563

552-
match_samples(
564+
samples = match_samples(
553565
date,
554566
samples,
555567
base_ts=base_ts,
556-
match_db=match_db,
557568
num_mismatches=num_mismatches,
558569
show_progress=show_progress,
559570
num_threads=num_threads,
560-
precision=precision,
561571
)
562572

573+
match_db.add(samples, date, num_mismatches)
563574
match_db.create_mask_table(base_ts)
564575
ts = increment_time(date, base_ts)
565576

@@ -810,23 +821,21 @@ def solve_num_mismatches(ts, k):
810821
NOTE! This is NOT taking into account the spatial distance along
811822
the genome, and so is not a very good model in some ways.
812823
"""
824+
# We can match against any node in tsinfer
813825
m = ts.num_sites
814-
n = ts.num_nodes # We can match against any node in tsinfer
815-
if k == 0:
816-
# Pathological things happen when k=0
817-
r = 1e-3
818-
mu = 1e-20
819-
else:
820-
# NOTE: the magnitude of mu matters because it puts a limit
821-
# on how low we can push the HMM precision. We should be able to solve
822-
# for the optimal value of this parameter such that the magnitude of the
823-
# values within the HMM are as large as possible (so that we can truncate
824-
# usefully).
825-
mu = 1e-2
826-
denom = (1 - mu) ** k + (n - 1) * mu**k
827-
r = n * mu**k / denom
828-
assert mu < 0.5
829-
assert r < 0.5
826+
n = ts.num_nodes
827+
# values of k <= 1 are not relevant for SC2 and lead to awkward corner cases
828+
assert k > 1
829+
830+
# NOTE: the magnitude of mu matters because it puts a limit
831+
# on how low we can push the HMM precision. We should be able to solve
832+
# for the optimal value of this parameter such that the magnitude of the
833+
# values within the HMM are as large as possible (so that we can truncate
834+
# usefully).
835+
# mu = 1e-2
836+
mu = 0.125
837+
denom = (1 - mu) ** k + (n - 1) * mu**k
838+
r = n * mu**k / denom
830839

831840
# Add a little bit of extra mass for recombination so that we deterministically
832841
# chose to recombine over k mutations

tests/test_inference.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,30 @@ def test_first_day(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_db
561561

562562
assert ts.tables.equals(fx_ts_map["2020-01-19"].tables, ignore_provenance=True)
563563

564+
def test_2020_02_02(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_db):
565+
ts = sc2ts.extend(
566+
alignment_store=fx_alignment_store,
567+
metadata_db=fx_metadata_db,
568+
base_ts=fx_ts_map["2020-02-01"],
569+
date="2020-02-02",
570+
match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db"),
571+
)
572+
assert ts.num_samples == 26
573+
assert np.sum(ts.nodes_time[ts.samples()] == 0) == 4
574+
samples = {}
575+
for u in ts.samples()[-4:]:
576+
node = ts.node(u)
577+
samples[node.metadata["strain"]] = node
578+
smd = node.metadata["sc2ts"]
579+
md = node.metadata
580+
print(md["date"], md["strain"], len(smd["mutations"]))
581+
# print(samples)
582+
# print(fx_ts_map["2020-02-01"])
583+
# print(ts)
584+
# print(fx_ts_map["2020-02-02"])
585+
ts.tables.assert_equals(fx_ts_map["2020-02-02"].tables, ignore_provenance=True)
586+
587+
564588
@pytest.mark.parametrize("date", dates)
565589
def test_date_metadata(self, fx_ts_map, date):
566590
ts = fx_ts_map[date]
@@ -675,7 +699,7 @@ class TestMatchingDetails:
675699
@pytest.mark.parametrize(
676700
("strain", "parent"), [("SRR11597207", 41), ("ERR4205570", 58)]
677701
)
678-
@pytest.mark.parametrize("num_mismatches", [1, 2, 3, 4])
702+
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
679703
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
680704
def test_exact_matches(
681705
self,
@@ -707,7 +731,7 @@ def test_exact_matches(
707731
("strain", "parent", "position", "derived_state"),
708732
[("SRR11597218", 10, 289, "T"), ("ERR4206593", 58, 26994, "T")],
709733
)
710-
@pytest.mark.parametrize("num_mismatches", [1, 2, 3, 4])
734+
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
711735
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
712736
def test_one_mismatch(
713737
self,

0 commit comments

Comments
 (0)