Skip to content

Commit 7d5a5be

Browse files
Debugging
1 parent bad0019 commit 7d5a5be

File tree

1 file changed

+35
-69
lines changed

1 file changed

+35
-69
lines changed

sc2ts/inference.py

Lines changed: 35 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -388,51 +388,6 @@ def asdict(self):
388388
}
389389

390390

391-
# def daily_extend(
392-
# *,
393-
# alignment_store,
394-
# metadata_db,
395-
# base_ts,
396-
# match_db,
397-
# num_mismatches=None,
398-
# max_hmm_cost=None,
399-
# min_group_size=None,
400-
# num_past_days=None,
401-
# show_progress=False,
402-
# max_submission_delay=None,
403-
# max_daily_samples=None,
404-
# num_threads=None,
405-
# precision=None,
406-
# rng=None,
407-
# excluded_sample_dir=None,
408-
# ):
409-
# assert num_past_days is None
410-
# assert max_submission_delay is None
411-
412-
# start_day = last_date(base_ts)
413-
414-
# last_ts = base_ts
415-
# for date in metadata_db.get_days(start_day):
416-
# ts = extend(
417-
# alignment_store=alignment_store,
418-
# metadata_db=metadata_db,
419-
# date=date,
420-
# base_ts=last_ts,
421-
# match_db=match_db,
422-
# num_mismatches=num_mismatches,
423-
# max_hmm_cost=max_hmm_cost,
424-
# min_group_size=min_group_size,
425-
# show_progress=show_progress,
426-
# max_submission_delay=max_submission_delay,
427-
# max_daily_samples=max_daily_samples,
428-
# num_threads=num_threads,
429-
# precision=precision,
430-
# )
431-
# yield ts, date
432-
433-
# last_ts = ts
434-
435-
436391
def match_samples(
437392
date,
438393
samples,
@@ -449,41 +404,50 @@ def match_samples(
449404
# Default to no recombination
450405
num_mismatches = 1000
451406

452-
match_tsinfer(
453-
samples=samples,
454-
ts=base_ts,
455-
num_mismatches=num_mismatches,
456-
precision=2,
457-
num_threads=num_threads,
458-
show_progress=show_progress,
459-
mirror_coordinates=mirror_coordinates,
460-
)
461-
samples_to_rerun = []
462-
for sample in samples:
463-
hmm_cost = sample.get_hmm_cost(num_mismatches)
464-
logger.debug(
465-
f"First sketch: {sample.strain} hmm_cost={hmm_cost} path={sample.path}"
466-
)
467-
if hmm_cost >= 2:
468-
sample.path.clear()
469-
sample.mutations.clear()
470-
samples_to_rerun.append(sample)
407+
remaining_samples = samples
408+
# FIXME Something wrong here, we don't seem to get precisely the same
409+
# ARG for some reason. Need to track it down
410+
# Also: should only run the things at low precision that have that HMM cost.
411+
# Start out by setting everything to have 0 mutations and work up from there.
471412

472-
if len(samples_to_rerun) > 0:
413+
for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]:
473414
match_tsinfer(
474-
samples=samples_to_rerun,
415+
samples=remaining_samples,
475416
ts=base_ts,
476417
num_mismatches=num_mismatches,
477418
precision=precision,
478419
num_threads=num_threads,
479420
show_progress=show_progress,
480421
mirror_coordinates=mirror_coordinates,
481422
)
482-
for sample in samples_to_rerun:
423+
samples_to_rerun = []
424+
for sample in remaining_samples:
483425
hmm_cost = sample.get_hmm_cost(num_mismatches)
426+
# print(f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}")
484427
logger.debug(
485-
f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
428+
f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}"
486429
)
430+
if hmm_cost > cost:
431+
sample.path.clear()
432+
sample.mutations.clear()
433+
samples_to_rerun.append(sample)
434+
remaining_samples = samples_to_rerun
435+
436+
match_tsinfer(
437+
samples=samples_to_rerun,
438+
ts=base_ts,
439+
num_mismatches=num_mismatches,
440+
precision=12,
441+
num_threads=num_threads,
442+
show_progress=show_progress,
443+
mirror_coordinates=mirror_coordinates,
444+
)
445+
for sample in samples_to_rerun:
446+
hmm_cost = sample.get_hmm_cost(num_mismatches)
447+
# print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}")
448+
logger.debug(
449+
f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
450+
)
487451

488452
match_db.add(samples, date, num_mismatches)
489453

@@ -855,7 +819,7 @@ def solve_num_mismatches(ts, k):
855819
# for the optimal value of this parameter such that the magnitude of the
856820
# values within the HMM are as large as possible (so that we can truncate
857821
# usefully).
858-
mu = 1e-3
822+
mu = 1e-2
859823
denom = (1 - mu) ** k + (n - 1) * mu**k
860824
r = n * mu**k / denom
861825
assert mu < 0.5
@@ -1312,6 +1276,8 @@ def match_tsinfer(
13121276
show_progress=False,
13131277
mirror_coordinates=False,
13141278
):
1279+
if len(samples) == 0:
1280+
return
13151281
genotypes = np.array([sample.alignment for sample in samples], dtype=np.int8).T
13161282
input_ts = ts
13171283
if mirror_coordinates:

0 commit comments

Comments
 (0)