Skip to content

Commit 4727522

Browse files
First pass at adaptive HMM precision
1 parent d7934a3 commit 4727522

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

sc2ts/cli.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -76,11 +76,11 @@ def setup_logging(verbosity, log_file=None):
7676
# to go to the log to see the traceback, but for production it may
7777
# be better to let daiquiri record the errors as well.
7878
daiquiri.setup(outputs=outputs, set_excepthook=False)
79-
# Only show stuff coming from sc2ts. Sometimes it's handy to look
80-
# at the tsinfer logs too, so we could add an option to set its
81-
# levels
79+
# Only show stuff coming from sc2ts and the relevant bits of tsinfer.
8280
logger = logging.getLogger("sc2ts")
8381
logger.setLevel(log_level)
82+
logger = logging.getLogger("tsinfer.inference")
83+
logger.setLevel(log_level)
8484

8585

8686
# TODO add options to list keys, dump specific alignments etc

sc2ts/inference.py

Lines changed: 27 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -506,11 +506,37 @@ def match_samples(
506506
samples=samples,
507507
ts=base_ts,
508508
num_mismatches=num_mismatches,
509-
precision=precision,
509+
precision=2,
510510
num_threads=num_threads,
511511
show_progress=show_progress,
512512
mirror_coordinates=mirror_coordinates,
513513
)
514+
samples_to_rerun = []
515+
for sample in samples:
516+
hmm_cost = sample.get_hmm_cost(num_mismatches)
517+
logger.debug(
518+
f"First sketch: {sample.strain} hmm_cost={hmm_cost} path={sample.path}"
519+
)
520+
if hmm_cost >= 2:
521+
sample.path.clear()
522+
sample.mutations.clear()
523+
samples_to_rerun.append(sample)
524+
525+
if len(samples_to_rerun) > 0:
526+
match_tsinfer(
527+
samples=samples_to_rerun,
528+
ts=base_ts,
529+
num_mismatches=num_mismatches,
530+
precision=precision,
531+
num_threads=num_threads,
532+
show_progress=show_progress,
533+
mirror_coordinates=mirror_coordinates,
534+
)
535+
for sample in samples_to_rerun:
536+
hmm_cost = sample.get_hmm_cost(num_mismatches)
537+
logger.debug(
538+
f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
539+
)
514540

515541
match_db.add(samples, date, num_mismatches)
516542

0 commit comments

Comments
 (0)